Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
29eec2dd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
29eec2dd
编写于
1月 04, 2023
作者:
L
lzy
提交者:
GitHub
1月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add multi_devices_fused_multi_transformer_encoder_pass and cherry-pick from 48349 (#49383)
上级
a2d7e1d7
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
2776 addition
and
1241 deletion
+2776
-1241
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
...luid/framework/ir/fused_multi_transformer_decoder_pass.cc
+50
-36
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h
...fluid/framework/ir/fused_multi_transformer_decoder_pass.h
+6
-6
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
...luid/framework/ir/fused_multi_transformer_encoder_pass.cc
+2341
-1140
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h
...fluid/framework/ir/fused_multi_transformer_encoder_pass.h
+164
-20
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc
...amework/ir/fused_multi_transformer_encoder_pass_tester.cc
+213
-39
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+2
-0
未找到文件。
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
浏览文件 @
29eec2dd
...
@@ -31,6 +31,8 @@ namespace framework {
...
@@ -31,6 +31,8 @@ namespace framework {
namespace
ir
{
namespace
ir
{
namespace
patterns
{
namespace
patterns
{
static
const
std
::
unordered_set
<
std
::
string
>
FFN_ACTS
{
"relu"
,
"gelu"
};
PDNode
*
FusedMultiTransformerDecoderPattern
::
operator
()()
{
PDNode
*
FusedMultiTransformerDecoderPattern
::
operator
()()
{
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
...
@@ -359,11 +361,11 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
...
@@ -359,11 +361,11 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op
_input
(
"gelu"
);
->
assert_is_op
s_input
(
FFN_ACTS
);
auto
*
ffn_
gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_
act
=
pattern
->
NewNode
(
ffn_act_repr
())
->
assert_is_ops
(
FFN_ACTS
);
auto
*
ffn_
gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu
_out_repr
())
auto
*
ffn_
act_out_var
=
pattern
->
NewNode
(
ffn_act
_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
assert_is_ops_output
(
FFN_ACTS
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
->
assert_is_op_input
(
"matmul_v2"
);
...
@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
...
@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
.
LinksTo
({
ffn_matmul0_out_var
});
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_
gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu
_out_var
});
ffn_
act
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_act
_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_
gelu
_out_var
,
ffn_matmul1_w_var
})
ffn_matmul1
->
LinksFrom
({
ffn_
act
_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
.
LinksTo
({
ffn_eltadd1_out_var
});
...
@@ -678,11 +680,11 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
...
@@ -678,11 +680,11 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op
_input
(
"gelu"
);
->
assert_is_op
s_input
(
FFN_ACTS
);
auto
*
ffn_
gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_
act
=
pattern
->
NewNode
(
ffn_act_repr
())
->
assert_is_ops
(
FFN_ACTS
);
auto
*
ffn_
gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu
_out_repr
())
auto
*
ffn_
act_out_var
=
pattern
->
NewNode
(
ffn_act
_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
assert_is_ops_output
(
FFN_ACTS
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
->
assert_is_op_input
(
"matmul_v2"
);
...
@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
...
@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
.
LinksTo
({
ffn_matmul0_out_var
});
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_
gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu
_out_var
});
ffn_
act
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_act
_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_
gelu
_out_var
,
ffn_matmul1_w_var
})
ffn_matmul1
->
LinksFrom
({
ffn_
act
_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
.
LinksTo
({
ffn_eltadd1_out_var
});
...
@@ -1026,11 +1028,11 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
...
@@ -1026,11 +1028,11 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op
_input
(
"gelu"
);
->
assert_is_op
s_input
(
FFN_ACTS
);
auto
*
ffn_
gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_
act
=
pattern
->
NewNode
(
ffn_act_repr
())
->
assert_is_ops
(
FFN_ACTS
);
auto
*
ffn_
gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu
_out_repr
())
auto
*
ffn_
act_out_var
=
pattern
->
NewNode
(
ffn_act
_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
assert_is_ops_output
(
FFN_ACTS
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
->
assert_is_op_input
(
"matmul_v2"
);
...
@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
...
@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
.
LinksTo
({
ffn_matmul0_out_var
});
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_
gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu
_out_var
});
ffn_
act
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_act
_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_
gelu
_out_var
,
ffn_matmul1_w_var
})
ffn_matmul1
->
LinksFrom
({
ffn_
act
_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_c_allreduce_sum
->
LinksFrom
({
ffn_matmul1_out_var
})
ffn_c_allreduce_sum
->
LinksFrom
({
ffn_matmul1_out_var
})
.
LinksTo
({
ffn_c_allreduce_sum_out_var
});
.
LinksTo
({
ffn_c_allreduce_sum_out_var
});
...
@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
...
@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
Node
*
ffn_matmul1_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_act
,
Node
*
ffn_output
)
{
Node
*
ffn_output
)
{
auto
*
matmul0_op
=
matmul0
->
Op
();
auto
*
matmul0_op
=
matmul0
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
...
@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
...
@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"act_method"
,
ffn_act
->
Op
()
->
Type
());
// output dropout attribute
// output dropout attribute
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
...
@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
...
@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_pattern
);
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu
,
ffn_gelu
,
fused_multi_transformer_pattern
);
ffn_
act
,
ffn_act
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu_out
,
ffn_gelu
_out
,
fused_multi_transformer_pattern
);
ffn_
act_out
,
ffn_act
_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_pattern
);
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_pattern
);
...
@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
...
@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_matmul1_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_eltadd1_b
,
ffn_act
,
ffn_output
);
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
...
@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
...
@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_eltadd1
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_eltadd1_out
,
ffn_
gelu
,
ffn_
act
,
ffn_
gelu
_out
,
ffn_
act
_out
,
ffn_eltadd_out
});
ffn_eltadd_out
});
// Remove unneeded nodes.
// Remove unneeded nodes.
...
@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node
*
ffn_matmul1_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_act
,
Node
*
ffn_output
)
{
Node
*
ffn_output
)
{
auto
*
matmul0_op
=
matmul0
->
Op
();
auto
*
matmul0_op
=
matmul0
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
...
@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"act_method"
,
ffn_act
->
Op
()
->
Type
());
// output dropout attribute
// output dropout attribute
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
...
@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern
);
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu
,
ffn_gelu
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_
act
,
ffn_act
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu_out
,
ffn_gelu
_out
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_
act_out
,
ffn_act
_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
...
@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_matmul1_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_eltadd1_b
,
ffn_act
,
ffn_output
);
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
...
@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_eltadd1
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_eltadd1_out
,
ffn_
gelu
,
ffn_
act
,
ffn_
gelu
_out
,
ffn_
act
_out
,
ffn_eltadd_out
});
ffn_eltadd_out
});
// Remove unneeded nodes.
// Remove unneeded nodes.
...
@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node
*
ffn_matmul1_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_act
,
Node
*
ffn_output
)
{
Node
*
ffn_output
)
{
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
auto
*
ffn_matmul_1_op
=
ffn_matmul1
->
Op
();
auto
*
ffn_matmul_1_op
=
ffn_matmul1
->
Op
();
...
@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"act_method"
,
ffn_act
->
Op
()
->
Type
());
// output dropout attribute
// output dropout attribute
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
...
@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern
);
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu
,
ffn_gelu
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_
act
,
ffn_act
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu_out
,
ffn_gelu
_out
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_
act_out
,
ffn_act
_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
...
@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_matmul1_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_eltadd1_b
,
ffn_act
,
ffn_output
);
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
...
@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_eltadd1
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_eltadd1_out
,
ffn_
gelu
,
ffn_
act
,
ffn_
gelu
_out
,
ffn_
act
_out
,
ffn_eltadd_out
});
ffn_eltadd_out
});
// Remove unneeded nodes.
// Remove unneeded nodes.
...
...
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h
浏览文件 @
29eec2dd
...
@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
...
@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_
gelu
);
PATTERN_DECL_NODE
(
ffn_
act
);
PATTERN_DECL_NODE
(
ffn_
gelu
_out
);
PATTERN_DECL_NODE
(
ffn_
act
_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
...
@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
...
@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_
gelu
);
PATTERN_DECL_NODE
(
ffn_
act
);
PATTERN_DECL_NODE
(
ffn_
gelu
_out
);
PATTERN_DECL_NODE
(
ffn_
act
_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
...
@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
...
@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_
gelu
);
PATTERN_DECL_NODE
(
ffn_
act
);
PATTERN_DECL_NODE
(
ffn_
gelu
_out
);
PATTERN_DECL_NODE
(
ffn_
act
_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
...
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
浏览文件 @
29eec2dd
...
@@ -25,44 +25,20 @@ namespace framework {
...
@@ -25,44 +25,20 @@ namespace framework {
namespace
ir
{
namespace
ir
{
namespace
patterns
{
namespace
patterns
{
PDNode
*
FusedMultiTransformerEncoderPattern
::
operator
()()
{
static
const
std
::
unordered_set
<
std
::
string
>
FFN_ACTS
{
"relu"
,
"gelu"
};
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
// pre-LayerNorm
PDNode
*
FusedMultiTransformerEncoderPattern
::
operator
()()
{
auto
*
layer_norm
=
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
())
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
assert_more
([](
Node
*
x
)
{
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
3
)
{
if
(
x
->
outputs
.
size
()
==
4
)
{
return
true
;
return
true
;
}
else
{
}
else
{
return
false
;
return
false
;
}
}
});
});
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
.
LinksTo
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
// Q path Nodes
// Q path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
...
@@ -95,15 +71,20 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -95,15 +71,20 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
->
assert_is_op_input
(
"scale"
);
auto
*
scale_q
=
pattern
->
NewNode
(
scale_q_repr
())
->
assert_is_op
(
"scale"
);
auto
*
scale_q_out_var
=
pattern
->
NewNode
(
scale_q_out_repr
())
->
assert_is_op_output
(
"scale"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// Q path Links
// Q path Links
matmul0
->
LinksFrom
({
layer_norm_out_var
,
matmul0_w_var
})
matmul0
->
LinksFrom
({
input0
,
matmul0_w_var
}).
LinksTo
({
matmul0_out_var
});
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
scale_q
->
LinksFrom
({
transpose2_0_out_var
}).
LinksTo
({
scale_q_out_var
});
// K path Nodes
// K path Nodes
auto
*
matmul1
=
pattern
->
NewNode
(
matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul1
=
pattern
->
NewNode
(
matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
...
@@ -137,20 +118,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -137,20 +118,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
pattern
->
NewNode
(
transpose2_1_repr
())
->
assert_is_op
(
"transpose2"
);
pattern
->
NewNode
(
transpose2_1_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_1_out_var
=
pattern
->
NewNode
(
transpose2_1_out_repr
())
auto
*
transpose2_1_out_var
=
pattern
->
NewNode
(
transpose2_1_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
assert_is_op_output
(
"transpose2"
)
->
AsOutput
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
->
assert_is_op_input
(
"while"
)
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
2
)
{
return
true
;
}
else
{
return
false
;
}
});
// K path Links
// K path Links
matmul1
->
LinksFrom
({
layer_norm_out_var
,
matmul1_w_var
})
matmul1
->
LinksFrom
({
input0
,
matmul1_w_var
}).
LinksTo
({
matmul1_out_var
});
.
LinksTo
({
matmul1_out_var
});
eltadd1
->
LinksFrom
({
matmul1_out_var
,
eltadd1_b_var
})
eltadd1
->
LinksFrom
({
matmul1_out_var
,
eltadd1_b_var
})
.
LinksTo
({
eltadd1_out_var
});
.
LinksTo
({
eltadd1_out_var
});
reshape2_1
->
LinksFrom
({
eltadd1_out_var
}).
LinksTo
({
reshape2_1_out_var
});
reshape2_1
->
LinksFrom
({
eltadd1_out_var
}).
LinksTo
({
reshape2_1_out_var
});
...
@@ -187,29 +159,21 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -187,29 +159,21 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
pattern
->
NewNode
(
transpose2_2_repr
())
->
assert_is_op
(
"transpose2"
);
pattern
->
NewNode
(
transpose2_2_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_2_out_var
=
pattern
->
NewNode
(
transpose2_2_out_repr
())
auto
*
transpose2_2_out_var
=
pattern
->
NewNode
(
transpose2_2_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
assert_is_op_output
(
"transpose2"
)
->
AsOutput
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
->
assert_is_op_input
(
"while"
)
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
2
)
{
return
true
;
}
else
{
return
false
;
}
});
// V path Links
// V path Links
matmul2
->
LinksFrom
({
layer_norm_out_var
,
matmul2_w_var
})
matmul2
->
LinksFrom
({
input0
,
matmul2_w_var
}).
LinksTo
({
matmul2_out_var
});
.
LinksTo
({
matmul2_out_var
});
eltadd2
->
LinksFrom
({
matmul2_out_var
,
eltadd2_b_var
})
eltadd2
->
LinksFrom
({
matmul2_out_var
,
eltadd2_b_var
})
.
LinksTo
({
eltadd2_out_var
});
.
LinksTo
({
eltadd2_out_var
});
reshape2_2
->
LinksFrom
({
eltadd2_out_var
}).
LinksTo
({
reshape2_2_out_var
});
reshape2_2
->
LinksFrom
({
eltadd2_out_var
}).
LinksTo
({
reshape2_2_out_var
});
transpose2_2
->
LinksFrom
({
reshape2_2_out_var
}).
LinksTo
({
transpose2_2_out_var
});
transpose2_2
->
LinksFrom
({
reshape2_2_out_var
}).
LinksTo
({
transpose2_2_out_var
});
// QK path Nodes
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul
_v2
"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
auto
*
eltadd_qk
=
...
@@ -230,7 +194,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -230,7 +194,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// QK path Linsk
// QK path Linsk
matmul_qk
->
LinksFrom
({
transpose2_0
_out_var
,
transpose2_1_out_var
})
matmul_qk
->
LinksFrom
({
scale_q
_out_var
,
transpose2_1_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
.
LinksTo
({
eltadd_qk_out_var
});
...
@@ -297,42 +261,41 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -297,42 +261,41 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
eltadd_out
->
LinksFrom
({
input0
,
eltadd_linear_out_var
})
eltadd_out
->
LinksFrom
({
input0
,
eltadd_linear_out_var
})
.
LinksTo
({
attention_output
});
.
LinksTo
({
attention_output
});
// while loop
// post-LayerNorm
auto
*
while0
=
pattern
->
NewNode
(
while0_repr
())
->
assert_is_op
(
"while"
);
auto
*
layer_norm
=
while0
->
LinksFrom
({
transpose2_1_out_var
,
transpose2_2_out_var
});
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
// Feed Forward LayerNorm Nodes
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsOutput
()
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_
layer_norm_variance_var
=
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_
layer_norm_variance_repr
())
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_
layer_norm_out_repr
())
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
->
assert_is_op_input
(
"matmul_v2"
,
"X"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
2
)
{
return
true
;
}
else
{
return
false
;
}
});
ffn_layer_norm
layer_norm
->
LinksFrom
(
->
LinksFrom
({
attention_output
,
layer_norm_bias_var
,
layer_norm_scale_var
})
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
(
.
LinksTo
({
ffn_layer_norm_out_var
,
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// Feed Forward fc1 -> gelu -> fc2
// Feed Forward fc1 -> gelu -> fc2
auto
*
ffn_matmul0
=
auto
*
ffn_matmul0
=
...
@@ -353,11 +316,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -353,11 +316,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op
_input
(
"gelu"
);
->
assert_is_op
s_input
(
FFN_ACTS
);
auto
*
ffn_
gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_
act
=
pattern
->
NewNode
(
ffn_act_repr
())
->
assert_is_ops
(
FFN_ACTS
);
auto
*
ffn_
gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu
_out_repr
())
auto
*
ffn_
act_out_var
=
pattern
->
NewNode
(
ffn_act
_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
assert_is_ops_output
(
FFN_ACTS
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
->
assert_is_op_input
(
"matmul_v2"
);
...
@@ -385,22 +348,55 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -385,22 +348,55 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
->
AsIntermediate
()
->
assert_is_op_input
(
"layer_norm"
);
ffn_matmul0
->
LinksFrom
({
ffn_
layer_norm_out_var
,
ffn_matmul0_w_var
})
ffn_matmul0
->
LinksFrom
({
layer_norm_out_var
,
ffn_matmul0_w_var
})
.
LinksTo
({
ffn_matmul0_out_var
});
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_
gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu
_out_var
});
ffn_
act
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_act
_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_
gelu
_out_var
,
ffn_matmul1_w_var
})
ffn_matmul1
->
LinksFrom
({
ffn_
act
_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
.
LinksTo
({
ffn_eltadd1_out_var
});
ffn_eltadd_out
->
LinksFrom
({
attention_output
,
ffn_eltadd1_out_var
})
ffn_eltadd_out
->
LinksFrom
({
layer_norm_out_var
,
ffn_eltadd1_out_var
})
.
LinksTo
({
ffn_output
});
.
LinksTo
({
ffn_output
});
return
ffn_output
;
// Feed Forward LayerNorm Nodes
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_layer_norm_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
);
ffn_layer_norm
->
LinksFrom
(
{
ffn_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
({
ffn_layer_norm_out_var
,
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
return
ffn_layer_norm_out_var
;
}
}
PDNode
*
FusedMultiTransformerEncoderFuseQKVPattern
::
operator
()()
{
PDNode
*
FusedMultiTransformerEncoderFuseQKVPattern
::
operator
()()
{
...
@@ -649,11 +645,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -649,11 +645,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op
_input
(
"gelu"
);
->
assert_is_op
s_input
(
FFN_ACTS
);
auto
*
ffn_
gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_
act
=
pattern
->
NewNode
(
ffn_act_repr
())
->
assert_is_ops
(
FFN_ACTS
);
auto
*
ffn_
gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu
_out_repr
())
auto
*
ffn_
act_out_var
=
pattern
->
NewNode
(
ffn_act
_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
assert_is_ops_output
(
FFN_ACTS
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
->
assert_is_op_input
(
"matmul_v2"
);
...
@@ -687,8 +683,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -687,8 +683,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
.
LinksTo
({
ffn_matmul0_out_var
});
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_
gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu
_out_var
});
ffn_
act
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_act
_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_
gelu
_out_var
,
ffn_matmul1_w_var
})
ffn_matmul1
->
LinksFrom
({
ffn_
act
_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
.
LinksTo
({
ffn_eltadd1_out_var
});
...
@@ -699,47 +695,41 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -699,47 +695,41 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
return
ffn_output
;
return
ffn_output
;
}
}
PDNode
*
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
::
operator
()()
{
PDNode
*
MultiDevicesFusedMultiTransformerEncoderPattern
::
operator
()()
{
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
())
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
->
assert_is_op_input
(
"c_identity"
,
"X"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
// pre-LayerNorm
->
assert_more
([](
Node
*
x
)
{
auto
*
layer_norm
=
if
(
x
->
outputs
.
size
()
==
4
)
{
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
return
true
;
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
}
else
{
->
AsInput
()
return
false
;
->
assert_is_persistable_var
()
}
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
});
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
.
LinksTo
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
// communication c_identity
// communication c_identity
auto
*
c_identity
=
auto
*
c_identity
0
=
pattern
->
NewNode
(
c_identity_repr
())
->
assert_is_op
(
"c_identity"
);
pattern
->
NewNode
(
c_identity
0
_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
c_identity
_out_var
=
pattern
->
NewNode
(
c_identity
_out_repr
())
auto
*
c_identity
0_out_var
=
pattern
->
NewNode
(
c_identity0
_out_repr
())
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
c_identity
->
LinksFrom
({
layer_norm_out_var
}).
LinksTo
({
c_identity_out_var
});
auto
*
c_identity1
=
pattern
->
NewNode
(
c_identity1_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
c_identity1_out_var
=
pattern
->
NewNode
(
c_identity1_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
auto
*
c_identity2
=
pattern
->
NewNode
(
c_identity2_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
c_identity2_out_var
=
pattern
->
NewNode
(
c_identity2_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
c_identity0
->
LinksFrom
({
input0
}).
LinksTo
({
c_identity0_out_var
});
c_identity1
->
LinksFrom
({
input0
}).
LinksTo
({
c_identity1_out_var
});
c_identity2
->
LinksFrom
({
input0
}).
LinksTo
({
c_identity2_out_var
});
// Q
KV fused
path Nodes
// Q path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
->
AsInput
()
->
AsInput
()
...
@@ -771,75 +761,137 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -771,75 +761,137 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"split"
,
"X"
);
->
assert_is_op_input
(
"scale"
);
auto
*
scale_q
=
pattern
->
NewNode
(
scale_q_repr
())
->
assert_is_op
(
"scale"
);
auto
*
split0
=
pattern
->
NewNode
(
split0_repr
())
->
assert_is_op
(
"split"
);
auto
*
scale_q_out_var
=
pattern
->
NewNode
(
scale_q_out_repr
())
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"scale"
)
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
auto
*
split0_v_out_var
=
pattern
->
NewNode
(
split0_v_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
// Q
KV fused
path Links
// Q path Links
matmul0
->
LinksFrom
({
c_identity_out_var
,
matmul0_w_var
})
matmul0
->
LinksFrom
({
c_identity
0
_out_var
,
matmul0_w_var
})
.
LinksTo
({
matmul0_out_var
});
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
split0
->
LinksFrom
({
transpose2_0_out_var
})
scale_q
->
LinksFrom
({
transpose2_0_out_var
}).
LinksTo
({
scale_q_out_var
});
.
LinksTo
({
split0_q_out_var
,
split0_k_out_var
,
split0_v_out_var
});
// while loop
auto
*
while0
=
pattern
->
NewNode
(
while0_repr
())
->
assert_is_op
(
"while"
);
while0
->
LinksFrom
({
split0_k_out_var
,
split0_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"scale"
);
auto
*
scale_qk
=
pattern
->
NewNode
(
scale_qk_repr
())
->
assert_is_op
(
"scale"
);
// K path Nodes
auto
*
scale_qk_out_var
=
pattern
->
NewNode
(
scale_qk_out_repr
())
auto
*
matmul1
=
pattern
->
NewNode
(
matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
->
assert_is_op_output
(
"scale"
)
auto
*
matmul1_w_var
=
pattern
->
NewNode
(
matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul1_out_var
=
pattern
->
NewNode
(
matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X
"
);
->
assert_is_op_input
(
"elementwise_add
"
);
auto
*
eltadd
_qk
=
auto
*
eltadd
1
=
pattern
->
NewNode
(
eltadd
_qk
_repr
())
->
assert_is_op
(
"elementwise_add"
);
pattern
->
NewNode
(
eltadd
1
_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd
_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk
_b_repr
())
auto
*
eltadd
1_b_var
=
pattern
->
NewNode
(
eltadd1
_b_repr
())
->
AsInput
()
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
auto
*
eltadd1_out_var
=
pattern
->
NewNode
(
eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax
"
);
->
assert_is_op_input
(
"reshape2
"
);
auto
*
softmax_qk
=
auto
*
reshape2_1
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax
"
);
pattern
->
NewNode
(
reshape2_1_repr
())
->
assert_is_op
(
"reshape2
"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk
_out_repr
())
auto
*
reshape2_1_out_var
=
pattern
->
NewNode
(
reshape2_1
_out_repr
())
->
assert_is_op_output
(
"
softmax
"
)
->
assert_is_op_output
(
"
reshape2
"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"
matmul_v2"
,
"X
"
);
->
assert_is_op_input
(
"
transpose2
"
);
// QK path Linsk
auto
*
transpose2_1
=
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
split0_k_out_var
})
pattern
->
NewNode
(
transpose2_1_repr
())
->
assert_is_op
(
"transpose2"
);
.
LinksTo
({
matmul_qk_out_var
});
auto
*
transpose2_1_out_var
=
pattern
->
NewNode
(
transpose2_1_out_repr
())
scale_qk
->
LinksFrom
({
matmul_qk_out_var
}).
LinksTo
({
scale_qk_out_var
});
->
assert_is_op_output
(
"transpose2"
)
eltadd_qk
->
LinksFrom
({
scale_qk_out_var
,
eltadd_qk_b_var
})
->
AsIntermediate
()
.
LinksTo
({
eltadd_qk_out_var
});
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
// K path Links
matmul1
->
LinksFrom
({
c_identity1_out_var
,
matmul1_w_var
})
.
LinksTo
({
matmul1_out_var
});
eltadd1
->
LinksFrom
({
matmul1_out_var
,
eltadd1_b_var
})
.
LinksTo
({
eltadd1_out_var
});
reshape2_1
->
LinksFrom
({
eltadd1_out_var
}).
LinksTo
({
reshape2_1_out_var
});
transpose2_1
->
LinksFrom
({
reshape2_1_out_var
}).
LinksTo
({
transpose2_1_out_var
});
// V path Nodes
auto
*
matmul2
=
pattern
->
NewNode
(
matmul2_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul2_w_var
=
pattern
->
NewNode
(
matmul2_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul2_out_var
=
pattern
->
NewNode
(
matmul2_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd2
=
pattern
->
NewNode
(
eltadd2_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd2_b_var
=
pattern
->
NewNode
(
eltadd2_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd2_out_var
=
pattern
->
NewNode
(
eltadd2_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_2
=
pattern
->
NewNode
(
reshape2_2_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_2_out_var
=
pattern
->
NewNode
(
reshape2_2_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_2
=
pattern
->
NewNode
(
transpose2_2_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_2_out_var
=
pattern
->
NewNode
(
transpose2_2_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
// V path Links
matmul2
->
LinksFrom
({
c_identity2_out_var
,
matmul2_w_var
})
.
LinksTo
({
matmul2_out_var
});
eltadd2
->
LinksFrom
({
matmul2_out_var
,
eltadd2_b_var
})
.
LinksTo
({
eltadd2_out_var
});
reshape2_2
->
LinksFrom
({
eltadd2_out_var
}).
LinksTo
({
reshape2_2_out_var
});
transpose2_2
->
LinksFrom
({
reshape2_2_out_var
}).
LinksTo
({
transpose2_2_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax"
);
auto
*
softmax_qk
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// QK path Linsk
matmul_qk
->
LinksFrom
({
scale_q_out_var
,
transpose2_1_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
// QKV path Nodes
// QKV path Nodes
auto
*
matmul_qkv
=
auto
*
matmul_qkv
=
...
@@ -897,7 +949,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -897,7 +949,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
->
AsIntermediate
();
->
AsIntermediate
();
// QKV path Links
// QKV path Links
matmul_qkv
->
LinksFrom
({
softmax_qk_out_var
,
split0_v
_out_var
})
matmul_qkv
->
LinksFrom
({
softmax_qk_out_var
,
transpose2_2
_out_var
})
.
LinksTo
({
matmul_qkv_out_var
});
.
LinksTo
({
matmul_qkv_out_var
});
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
.
LinksTo
({
transpose2_qkv_out_var
});
.
LinksTo
({
transpose2_qkv_out_var
});
...
@@ -912,38 +964,41 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -912,38 +964,41 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
eltadd_out
->
LinksFrom
({
input0
,
eltadd_linear_out_var
})
eltadd_out
->
LinksFrom
({
input0
,
eltadd_linear_out_var
})
.
LinksTo
({
attention_output
});
.
LinksTo
({
attention_output
});
// Feed Forward LayerNorm Nodes
// post-LayerNorm
auto
*
ffn_layer_norm
=
auto
*
layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsOutput
()
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_
layer_norm_variance_var
=
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_
layer_norm_variance_repr
())
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
As
Intermediate
()
->
As
Output
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_
layer_norm_out_repr
())
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
->
assert_is_op_input
(
"c_identity"
,
"X"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
2
)
{
return
true
;
}
else
{
return
false
;
}
});
ffn_layer_norm
layer_norm
->
LinksFrom
(
->
LinksFrom
({
attention_output
,
layer_norm_bias_var
,
layer_norm_scale_var
})
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
(
.
LinksTo
({
ffn_layer_norm_out_var
,
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// communication c_identity
// communication c_identity
auto
*
ffn_c_identity
=
auto
*
ffn_c_identity
=
...
@@ -952,7 +1007,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -952,7 +1007,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
ffn_c_identity
->
LinksFrom
({
ffn_
layer_norm_out_var
})
ffn_c_identity
->
LinksFrom
({
layer_norm_out_var
})
.
LinksTo
({
ffn_c_identity_out_var
});
.
LinksTo
({
ffn_c_identity_out_var
});
// Feed Forward fc1 -> gelu -> fc2
// Feed Forward fc1 -> gelu -> fc2
...
@@ -974,11 +1029,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -974,11 +1029,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op
_input
(
"gelu"
);
->
assert_is_op
s_input
(
FFN_ACTS
);
auto
*
ffn_
gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_
act
=
pattern
->
NewNode
(
ffn_act_repr
())
->
assert_is_ops
(
FFN_ACTS
);
auto
*
ffn_
gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu
_out_repr
())
auto
*
ffn_
act_out_var
=
pattern
->
NewNode
(
ffn_act
_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
assert_is_ops_output
(
FFN_ACTS
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
->
assert_is_op_input
(
"matmul_v2"
);
...
@@ -1015,297 +1070,1504 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -1015,297 +1070,1504 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
->
AsIntermediate
()
->
assert_is_op_input
(
"layer_norm"
);
ffn_matmul0
->
LinksFrom
({
ffn_c_identity_out_var
,
ffn_matmul0_w_var
})
ffn_matmul0
->
LinksFrom
({
ffn_c_identity_out_var
,
ffn_matmul0_w_var
})
.
LinksTo
({
ffn_matmul0_out_var
});
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_
gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu
_out_var
});
ffn_
act
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_act
_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_
gelu
_out_var
,
ffn_matmul1_w_var
})
ffn_matmul1
->
LinksFrom
({
ffn_
act
_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_c_allreduce_sum
->
LinksFrom
({
ffn_matmul1_out_var
})
ffn_c_allreduce_sum
->
LinksFrom
({
ffn_matmul1_out_var
})
.
LinksTo
({
ffn_c_allreduce_sum_out_var
});
.
LinksTo
({
ffn_c_allreduce_sum_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_c_allreduce_sum_out_var
,
ffn_eltadd1_b_var
})
ffn_eltadd1
->
LinksFrom
({
ffn_c_allreduce_sum_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
.
LinksTo
({
ffn_eltadd1_out_var
});
ffn_eltadd_out
->
LinksFrom
({
attention_output
,
ffn_eltadd1_out_var
})
ffn_eltadd_out
->
LinksFrom
({
layer_norm_out_var
,
ffn_eltadd1_out_var
})
.
LinksTo
({
ffn_output
});
.
LinksTo
({
ffn_output
});
return
ffn_output
;
// Feed Forward LayerNorm Nodes
}
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
}
// namespace patterns
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_layer_norm_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
);
template
<
typename
T
>
ffn_layer_norm
inline
void
QKVWeightsProcess
(
phi
::
DenseTensor
*
wq_tensor
,
->
LinksFrom
(
phi
::
DenseTensor
*
wk_tensor
,
{
ffn_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
phi
::
DenseTensor
*
wv_tensor
,
.
LinksTo
({
ffn_layer_norm_out_var
,
const
int
num_head
,
ffn_layer_norm_mean_var
,
const
int
dim_head
,
ffn_layer_norm_variance_var
});
const
int
dim_embed
)
{
auto
*
wq_data
=
wq_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
*
wk_data
=
wk_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
*
wv_data
=
wv_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
combined_w_dims
=
phi
::
make_ddim
({
3
,
num_head
,
dim_head
,
dim_embed
});
return
ffn_layer_norm_out_var
;
}
phi
::
DenseTensor
tmp_combined_w_tensor
;
PDNode
*
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
::
operator
()()
{
tmp_combined_w_tensor
.
Resize
(
combined_w_dims
);
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
auto
*
tmp_combined_w_data
=
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
tmp_combined_w_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
std
::
vector
<
T
*>
w_vec
=
{
wq_data
,
wk_data
,
wv_data
};
// pre-LayerNorm
// Combine the three fc weights together.
auto
*
layer_norm
=
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
for
(
int
j
=
0
;
j
<
num_head
;
j
++
)
{
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
for
(
int
k
=
0
;
k
<
dim_head
;
k
++
)
{
->
AsInput
()
for
(
int
l
=
0
;
l
<
dim_embed
;
l
++
)
{
->
assert_is_persistable_var
()
int
out_idx
=
i
*
num_head
*
dim_head
*
dim_embed
+
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
j
*
dim_head
*
dim_embed
+
k
*
dim_embed
+
l
;
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
int
in_idx
=
l
*
num_head
*
dim_head
+
j
*
dim_head
+
k
;
->
AsInput
()
tmp_combined_w_data
[
out_idx
]
=
w_vec
[
i
][
in_idx
];
->
assert_is_persistable_var
()
}
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
}
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
}
->
AsOutput
()
}
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
wq_tensor
->
Resize
(
combined_w_dims
);
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
auto
*
new_combined_w_data
=
wq_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
.
LinksTo
(
memcpy
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
new_combined_w_data
,
tmp_combined_w_data
,
sizeof
(
T
)
*
wq_tensor
->
numel
());
}
template
<
typename
T
>
// communication c_identity
inline
void
QKVBiasProcess
(
phi
::
DenseTensor
*
bq_tensor
,
auto
*
c_identity
=
phi
::
DenseTensor
*
bk_tensor
,
pattern
->
NewNode
(
c_identity_repr
())
->
assert_is_op
(
"c_identity"
);
phi
::
DenseTensor
*
bv_tensor
,
auto
*
c_identity_out_var
=
pattern
->
NewNode
(
c_identity_out_repr
())
const
int
num_head
,
->
AsIntermediate
()
const
int
dim_head
,
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
const
int
dim_embed
)
{
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
auto
*
bq_data
=
bq_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
c_identity
->
LinksFrom
({
layer_norm_out_var
}).
LinksTo
({
c_identity_out_var
});
auto
*
bk_data
=
bk_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
*
bv_data
=
bv_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
combined_bias_dims
=
phi
::
make_ddim
({
3
,
num_head
,
dim_head
});
// QKV fused path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul0_out_var
=
pattern
->
NewNode
(
matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
phi
::
DenseTensor
tmp_combined_bias_tensor
;
auto
*
eltadd0
=
tmp_combined_bias_tensor
.
Resize
(
combined_bias_dims
);
pattern
->
NewNode
(
eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
tmp_combined_bias_data
=
auto
*
eltadd0_b_var
=
pattern
->
NewNode
(
eltadd0_b_repr
())
tmp_combined_bias_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd0_out_var
=
pattern
->
NewNode
(
eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
size_t
bias_size
=
bq_tensor
->
numel
();
auto
*
reshape2_0
=
pattern
->
NewNode
(
reshape2_0_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_0_out_var
=
pattern
->
NewNode
(
reshape2_0_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_0
=
pattern
->
NewNode
(
transpose2_0_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"split"
,
"X"
);
auto
*
split0
=
pattern
->
NewNode
(
split0_repr
())
->
assert_is_op
(
"split"
);
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
auto
*
split0_v_out_var
=
pattern
->
NewNode
(
split0_v_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
// QKV fused path Links
matmul0
->
LinksFrom
({
c_identity_out_var
,
matmul0_w_var
})
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
split0
->
LinksFrom
({
transpose2_0_out_var
})
.
LinksTo
({
split0_q_out_var
,
split0_k_out_var
,
split0_v_out_var
});
// while loop
auto
*
while0
=
pattern
->
NewNode
(
while0_repr
())
->
assert_is_op
(
"while"
);
while0
->
LinksFrom
({
split0_k_out_var
,
split0_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"scale"
);
auto
*
scale_qk
=
pattern
->
NewNode
(
scale_qk_repr
())
->
assert_is_op
(
"scale"
);
auto
*
scale_qk_out_var
=
pattern
->
NewNode
(
scale_qk_out_repr
())
->
assert_is_op_output
(
"scale"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax"
);
auto
*
softmax_qk
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// QK path Linsk
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
split0_k_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
scale_qk
->
LinksFrom
({
matmul_qk_out_var
}).
LinksTo
({
scale_qk_out_var
});
eltadd_qk
->
LinksFrom
({
scale_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
// QKV path Nodes
auto
*
matmul_qkv
=
pattern
->
NewNode
(
matmul_qkv_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qkv_out_var
=
pattern
->
NewNode
(
matmul_qkv_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_qkv
=
pattern
->
NewNode
(
transpose2_qkv_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_qkv_out_var
=
pattern
->
NewNode
(
transpose2_qkv_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
transpose2_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_qkv
=
pattern
->
NewNode
(
reshape2_qkv_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_qkv_out_var
=
pattern
->
NewNode
(
reshape2_qkv_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
// -> out_linear
auto
*
matmul_linear
=
pattern
->
NewNode
(
matmul_linear_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_linear_w_var
=
pattern
->
NewNode
(
matmul_linear_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul_linear_out_var
=
pattern
->
NewNode
(
matmul_linear_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"c_allreduce_sum"
);
// communication c_allreduce_sum
auto
*
c_allreduce_sum
=
pattern
->
NewNode
(
c_allreduce_sum_repr
())
->
assert_is_op
(
"c_allreduce_sum"
);
auto
*
c_allreduce_sum_out_var
=
pattern
->
NewNode
(
c_allreduce_sum_out_repr
())
->
assert_is_op_output
(
"c_allreduce_sum"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_linear
=
pattern
->
NewNode
(
eltadd_linear_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_linear_b_var
=
pattern
->
NewNode
(
eltadd_linear_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_linear_out_var
=
pattern
->
NewNode
(
eltadd_linear_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_out
=
pattern
->
NewNode
(
eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
attention_output
=
pattern
->
NewNode
(
attention_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
();
// QKV path Links
matmul_qkv
->
LinksFrom
({
softmax_qk_out_var
,
split0_v_out_var
})
.
LinksTo
({
matmul_qkv_out_var
});
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
.
LinksTo
({
transpose2_qkv_out_var
});
reshape2_qkv
->
LinksFrom
({
transpose2_qkv_out_var
})
.
LinksTo
({
reshape2_qkv_out_var
});
matmul_linear
->
LinksFrom
({
reshape2_qkv_out_var
,
matmul_linear_w_var
})
.
LinksTo
({
matmul_linear_out_var
});
c_allreduce_sum
->
LinksFrom
({
matmul_linear_out_var
})
.
LinksTo
({
c_allreduce_sum_out_var
});
eltadd_linear
->
LinksFrom
({
c_allreduce_sum_out_var
,
eltadd_linear_b_var
})
.
LinksTo
({
eltadd_linear_out_var
});
eltadd_out
->
LinksFrom
({
input0
,
eltadd_linear_out_var
})
.
LinksTo
({
attention_output
});
// Feed Forward LayerNorm Nodes
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_layer_norm_variance_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
ffn_layer_norm
->
LinksFrom
(
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
({
ffn_layer_norm_out_var
,
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// communication c_identity
auto
*
ffn_c_identity
=
pattern
->
NewNode
(
ffn_c_identity_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
ffn_c_identity_out_var
=
pattern
->
NewNode
(
ffn_c_identity_out_repr
())
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
ffn_c_identity
->
LinksFrom
({
ffn_layer_norm_out_var
})
.
LinksTo
({
ffn_c_identity_out_var
});
// Feed Forward fc1 -> gelu -> fc2
auto
*
ffn_matmul0
=
pattern
->
NewNode
(
ffn_matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul0_w_var
=
pattern
->
NewNode
(
ffn_matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul0_out_var
=
pattern
->
NewNode
(
ffn_matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd0
=
pattern
->
NewNode
(
ffn_eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd0_b_var
=
pattern
->
NewNode
(
ffn_eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_ops_input
(
FFN_ACTS
);
auto
*
ffn_act
=
pattern
->
NewNode
(
ffn_act_repr
())
->
assert_is_ops
(
FFN_ACTS
);
auto
*
ffn_act_out_var
=
pattern
->
NewNode
(
ffn_act_out_repr
())
->
assert_is_ops_output
(
FFN_ACTS
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
auto
*
ffn_matmul1
=
pattern
->
NewNode
(
ffn_matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul1_w_var
=
pattern
->
NewNode
(
ffn_matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul1_out_var
=
pattern
->
NewNode
(
ffn_matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"c_allreduce_sum"
);
// communication c_allreduce_sum
auto
*
ffn_c_allreduce_sum
=
pattern
->
NewNode
(
ffn_c_allreduce_sum_repr
())
->
assert_is_op
(
"c_allreduce_sum"
);
auto
*
ffn_c_allreduce_sum_out_var
=
pattern
->
NewNode
(
ffn_c_allreduce_sum_out_repr
())
->
assert_is_op_output
(
"c_allreduce_sum"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd1
=
pattern
->
NewNode
(
ffn_eltadd1_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd1_b_var
=
pattern
->
NewNode
(
ffn_eltadd1_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd1_out_var
=
pattern
->
NewNode
(
ffn_eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd_out
=
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
ffn_matmul0
->
LinksFrom
({
ffn_c_identity_out_var
,
ffn_matmul0_w_var
})
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_act
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_act_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_act_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_c_allreduce_sum
->
LinksFrom
({
ffn_matmul1_out_var
})
.
LinksTo
({
ffn_c_allreduce_sum_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_c_allreduce_sum_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
ffn_eltadd_out
->
LinksFrom
({
attention_output
,
ffn_eltadd1_out_var
})
.
LinksTo
({
ffn_output
});
return
ffn_output
;
}
}
// namespace patterns
// Packs the separate Q/K/V weight tensors into one combined weight of shape
// [3, num_head, dim_head, dim_embed], transposing each source out of its
// matmul layout (indexing below treats each source as row-major
// [dim_embed, num_head, dim_head]). The packed result is written back into
// wq_tensor, which is reused as the combined-weight variable.
template <typename T>
inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor,
                              phi::DenseTensor* wk_tensor,
                              phi::DenseTensor* wv_tensor,
                              const int num_head,
                              const int dim_head,
                              const int dim_embed) {
  auto* cpu_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));

  std::vector<T*> src_weights = {
      wq_tensor->data<T>(), wk_tensor->data<T>(), wv_tensor->data<T>()};

  auto packed_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});

  // Build the packed layout in a scratch tensor before overwriting wq_tensor.
  phi::DenseTensor scratch;
  scratch.Resize(packed_dims);
  cpu_ctx->Alloc<T>(&scratch);
  auto* scratch_data = scratch.data<T>();

  // Combine the three fc weights together.
  for (int qkv = 0; qkv < 3; qkv++) {
    for (int head = 0; head < num_head; head++) {
      for (int hd = 0; hd < dim_head; hd++) {
        for (int emb = 0; emb < dim_embed; emb++) {
          // Destination index in row-major [3, num_head, dim_head, dim_embed].
          int dst =
              ((qkv * num_head + head) * dim_head + hd) * dim_embed + emb;
          // Source index in row-major [dim_embed, num_head, dim_head].
          int src = (emb * num_head + head) * dim_head + hd;
          scratch_data[dst] = src_weights[qkv][src];
        }
      }
    }
  }

  // Reuse wq_tensor to hold the packed weights.
  wq_tensor->Resize(packed_dims);
  cpu_ctx->Alloc<T>(wq_tensor);
  auto* packed_data = wq_tensor->data<T>();
  memcpy(packed_data, scratch_data, sizeof(T) * wq_tensor->numel());
}
// Concatenates the separate Q/K/V bias tensors into one combined bias of
// shape [3, num_head, dim_head]: slot 0 = Q, slot 1 = K, slot 2 = V.
// The result is written back into bq_tensor, which is reused as the
// combined-bias variable. dim_embed is unused here but kept so all QKV
// helpers share a uniform signature.
//
// NOTE(review): the extracted original contained duplicated memcpy/Resize
// statements and a double declaration of new_combined_bias_data (a compile
// error); this is the single intended copy sequence, matching the structure
// of QKVWeightsProcess above.
template <typename T>
inline void QKVBiasProcess(phi::DenseTensor* bq_tensor,
                           phi::DenseTensor* bk_tensor,
                           phi::DenseTensor* bv_tensor,
                           const int num_head,
                           const int dim_head,
                           const int dim_embed) {
  auto* dev_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  auto* bq_data = bq_tensor->data<T>();
  auto* bk_data = bk_tensor->data<T>();
  auto* bv_data = bv_tensor->data<T>();

  auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head});

  // Stage the concatenation in a temporary buffer, then copy it back into
  // bq_tensor once all three sources have been read.
  phi::DenseTensor tmp_combined_bias_tensor;
  tmp_combined_bias_tensor.Resize(combined_bias_dims);
  dev_ctx->Alloc<T>(&tmp_combined_bias_tensor);
  auto* tmp_combined_bias_data = tmp_combined_bias_tensor.data<T>();

  size_t bias_size = bq_tensor->numel();
  memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
  memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
  memcpy(tmp_combined_bias_data + 2 * bias_size,
         bv_data,
         sizeof(T) * bias_size);

  bq_tensor->Resize(combined_bias_dims);
  dev_ctx->Alloc<T>(bq_tensor);
  auto* new_combined_bias_data = bq_tensor->data<T>();
  memcpy(new_combined_bias_data,
         tmp_combined_bias_data,
         sizeof(T) * bq_tensor->numel());
}
// Dispatches QKV weight/bias combining on the tensors' runtime dtypes.
// Weights accept fp32/fp16/int8; biases accept fp32/fp16 only. Any other
// dtype raises an Unavailable error.
inline void QKVWeightsBiasProcess(phi::DenseTensor* wq_tensor,
                                  phi::DenseTensor* wk_tensor,
                                  phi::DenseTensor* wv_tensor,
                                  phi::DenseTensor* bq_tensor,
                                  phi::DenseTensor* bk_tensor,
                                  phi::DenseTensor* bv_tensor,
                                  const int num_head,
                                  const int dim_head,
                                  const int dim_embed) {
  const auto w_dtype = wq_tensor->dtype();
  if (w_dtype == paddle::experimental::DataType::FLOAT16) {
    QKVWeightsProcess<platform::float16>(
        wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
  } else if (w_dtype == paddle::experimental::DataType::FLOAT32) {
    QKVWeightsProcess<float>(
        wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
  } else if (w_dtype == paddle::experimental::DataType::INT8) {
    QKVWeightsProcess<int8_t>(
        wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "fused_multi_transformer not supported weight dtype. "
        "we now only support fp32/fp16/int8."));
  }

  const auto b_dtype = bq_tensor->dtype();
  if (b_dtype == paddle::experimental::DataType::FLOAT16) {
    QKVBiasProcess<platform::float16>(
        bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
  } else if (b_dtype == paddle::experimental::DataType::FLOAT32) {
    QKVBiasProcess<float>(
        bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "fused_multi_transformer not supported bias dtype. "
        "we now only support fp32/fp16."));
  }
}
// Transposes the fused QKV matmul weight out of its matmul layout (indexing
// below treats the source as row-major [dim_embed, num_head, 3, dim_head])
// into the [3, num_head, dim_head, dim_embed] layout expected by
// fused_multi_transformer. The result replaces qkv_w_tensor in place.
template <typename T>
inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
                                     const int num_head,
                                     const int dim_head,
                                     const int dim_embed) {
  auto* cpu_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  auto* src_data = qkv_w_tensor->data<T>();

  auto dst_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});

  // Stage the transpose in a scratch tensor before overwriting the source.
  phi::DenseTensor scratch;
  scratch.Resize(dst_dims);
  cpu_ctx->Alloc<T>(&scratch);
  auto* scratch_data = scratch.data<T>();

  // transpose qkv matmul Y to QKVWeights
  for (int qkv = 0; qkv < 3; qkv++) {
    for (int head = 0; head < num_head; head++) {
      for (int hd = 0; hd < dim_head; hd++) {
        for (int emb = 0; emb < dim_embed; emb++) {
          // Destination index in row-major [3, num_head, dim_head, dim_embed].
          int dst =
              ((qkv * num_head + head) * dim_head + hd) * dim_embed + emb;
          // Source index in row-major [dim_embed, num_head, 3, dim_head].
          int src = ((emb * num_head + head) * 3 + qkv) * dim_head + hd;
          scratch_data[dst] = src_data[src];
        }
      }
    }
  }

  qkv_w_tensor->Resize(dst_dims);
  cpu_ctx->Alloc<T>(qkv_w_tensor);
  auto* dst_data = qkv_w_tensor->data<T>();
  memcpy(dst_data, scratch_data, sizeof(T) * qkv_w_tensor->numel());
}
// Transposes the fused QKV elementwise_add bias out of its add layout
// (indexing below treats the source as row-major [num_head, 3, dim_head])
// into the [3, num_head, dim_head] layout expected by
// fused_multi_transformer. The result replaces qkv_b_tensor in place.
// dim_embed is unused here but kept for a uniform helper signature.
template <typename T>
inline void QKVBiasProcessFuseQKV(phi::DenseTensor* qkv_b_tensor,
                                  const int num_head,
                                  const int dim_head,
                                  const int dim_embed) {
  auto* cpu_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  auto* src_data = qkv_b_tensor->data<T>();

  auto dst_dims = phi::make_ddim({3, num_head, dim_head});

  // Stage the transpose in a scratch tensor before overwriting the source.
  phi::DenseTensor scratch;
  scratch.Resize(dst_dims);
  cpu_ctx->Alloc<T>(&scratch);
  auto* scratch_data = scratch.data<T>();

  // transpose qkv elemenwise_add Y to QKVBias
  for (int qkv = 0; qkv < 3; qkv++) {
    for (int head = 0; head < num_head; head++) {
      for (int hd = 0; hd < dim_head; hd++) {
        // Destination index in row-major [3, num_head, dim_head].
        int dst = (qkv * num_head + head) * dim_head + hd;
        // Source index in row-major [num_head, 3, dim_head].
        int src = (head * 3 + qkv) * dim_head + hd;
        scratch_data[dst] = src_data[src];
      }
    }
  }

  qkv_b_tensor->Resize({3, num_head, dim_head});
  cpu_ctx->Alloc<T>(qkv_b_tensor);
  auto* dst_data = qkv_b_tensor->data<T>();
  memcpy(dst_data, scratch_data, sizeof(T) * qkv_b_tensor->numel());
}
// Dispatches fused-QKV weight/bias transposition on the tensors' runtime
// dtypes. Weights accept fp32/fp16/int8; biases accept fp32/fp16 only. Any
// other dtype raises an Unavailable error.
inline void QKVWeightsBiasProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
                                         phi::DenseTensor* qkv_b_tensor,
                                         const int num_head,
                                         const int dim_head,
                                         const int dim_embed) {
  const auto w_dtype = qkv_w_tensor->dtype();
  if (w_dtype == paddle::experimental::DataType::FLOAT16) {
    QKVWeightsProcessFuseQKV<platform::float16>(
        qkv_w_tensor, num_head, dim_head, dim_embed);
  } else if (w_dtype == paddle::experimental::DataType::FLOAT32) {
    QKVWeightsProcessFuseQKV<float>(
        qkv_w_tensor, num_head, dim_head, dim_embed);
  } else if (w_dtype == paddle::experimental::DataType::INT8) {
    QKVWeightsProcessFuseQKV<int8_t>(
        qkv_w_tensor, num_head, dim_head, dim_embed);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "fused_multi_transformer not supported weight dtype. "
        "we now only support fp32/fp16/int8."));
  }

  const auto b_dtype = qkv_b_tensor->dtype();
  if (b_dtype == paddle::experimental::DataType::FLOAT16) {
    QKVBiasProcessFuseQKV<platform::float16>(
        qkv_b_tensor, num_head, dim_head, dim_embed);
  } else if (b_dtype == paddle::experimental::DataType::FLOAT32) {
    QKVBiasProcessFuseQKV<float>(qkv_b_tensor, num_head, dim_head, dim_embed);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "fused_multi_transformer not supported bias dtype. "
        "we now only support fp32/fp16."));
  }
}
// Just use for fused_multi_transformer_int8: transposes a 2-D int8 weight
// tensor of shape [m, n] into [n, m], replacing the tensor in place.
inline void TransposeWeights(phi::DenseTensor* weight_tensor) {
  auto* cpu_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  const int rows = weight_tensor->dims()[0];
  const int cols = weight_tensor->dims()[1];

  // Stage the transpose in a scratch tensor before overwriting the source.
  phi::DenseTensor scratch;
  scratch.Resize({cols, rows});
  cpu_ctx->Alloc<int8_t>(&scratch);
  auto scratch_data = scratch.data<int8_t>();
  auto src_data = weight_tensor->data<int8_t>();

  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      scratch_data[c * rows + r] = src_data[r * cols + c];
    }
  }

  weight_tensor->Resize({cols, rows});
  cpu_ctx->Alloc<int8_t>(weight_tensor);
  auto dst_data = weight_tensor->data<int8_t>();
  memcpy(dst_data, scratch_data, sizeof(int8_t) * rows * cols);
}
// Creates a persistable FP32 variable node with the given name and registers
// it in the graph; returns the new node.
inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) {
  VarDesc desc(name);
  desc.SetPersistable(true);
  desc.SetDataType(framework::proto::VarType::FP32);
  return graph->CreateVarNode(&desc);
}
int
FusedMultiTransformerEncoderPass
::
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
bool
enable_int8
=
graph
->
Get
<
bool
>
(
"enable_int8"
);
if
(
enable_int8
)
{
VLOG
(
3
)
<<
"FusedMultiTransformerEncoderPass with int8"
;
}
else
{
VLOG
(
3
)
<<
"FusedMultiTransformerEncoderPass with fp"
;
}
// Create pattern.
patterns
::
FusedMultiTransformerEncoderPattern
fused_multi_transformer_pattern
(
pattern
,
name_scope
);
fused_multi_transformer_pattern
();
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
layer_norm
,
Node
*
layer_norm_scale
,
Node
*
layer_norm_bias
,
Node
*
layer_norm_mean
,
Node
*
layer_norm_variance
,
Node
*
matmul0
,
Node
*
matmul0_w
,
Node
*
matmul1_w
,
Node
*
matmul2_w
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
transpose2_1_out
,
Node
*
transpose2_2_out
,
Node
*
eltadd_qk_b
,
Node
*
reshape2_0
,
Node
*
matmul_linear
,
Node
*
matmul_linear_w
,
Node
*
eltadd_linear_b
,
Node
*
ffn_layer_norm
,
Node
*
ffn_layer_norm_scale
,
Node
*
ffn_layer_norm_bias
,
Node
*
ffn_layer_norm_mean
,
Node
*
ffn_layer_norm_variance
,
Node
*
ffn_matmul0
,
Node
*
ffn_matmul0_w
,
Node
*
ffn_matmul1
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_act
,
Node
*
ffn_layer_norm_out
)
{
auto
*
matmul0_op
=
matmul0
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
auto
*
ffn_matmul_0_op
=
ffn_matmul0
->
Op
();
auto
*
ffn_matmul_1_op
=
ffn_matmul1
->
Op
();
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
// 2. each transformer layer contains 2 LayerNorm layer
auto
ln_scale_name
=
layer_norm_scale
->
Name
();
auto
ln_name
=
ln_scale_name
.
substr
(
0
,
ln_scale_name
.
find
(
'.'
));
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
auto
*
wq_tensor
=
scope
->
FindVar
(
matmul0_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
wk_tensor
=
scope
->
FindVar
(
matmul1_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
wv_tensor
=
scope
->
FindVar
(
matmul2_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
bq_tensor
=
scope
->
FindVar
(
eltadd0_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
bk_tensor
=
scope
->
FindVar
(
eltadd1_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
bv_tensor
=
scope
->
FindVar
(
eltadd2_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
// NOTE(minghaoBD): to make it compatible with strucutured pruning on
// num_head dimension:
// 1. get dim_head from reshape.shape[3], dim_embed from
// layer_norm_bias.shape[0]
// 2. calculate num_head according to wq_tensor.shape[1] and dim_head
auto
reshape_desc
=
reshape2_0
->
Op
();
int
dim_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
3
);
auto
*
layer_norm_bias_tensor
=
scope
->
FindVar
(
layer_norm_bias
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
int
dim_embed
=
layer_norm_bias_tensor
->
dims
()[
0
];
int
num_head
=
wq_tensor
->
dims
()[
1
]
/
dim_head
;
QKVWeightsBiasProcess
(
wq_tensor
,
wk_tensor
,
wv_tensor
,
bq_tensor
,
bk_tensor
,
bv_tensor
,
num_head
,
dim_head
,
dim_embed
);
if
(
enable_int8
)
{
auto
*
out_linear_w_tensor
=
scope
->
FindVar
(
matmul_linear_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
ffn0_w_tensor
=
scope
->
FindVar
(
ffn_matmul0_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
ffn1_w_tensor
=
scope
->
FindVar
(
ffn_matmul1_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
TransposeWeights
(
out_linear_w_tensor
);
TransposeWeights
(
ffn0_w_tensor
);
TransposeWeights
(
ffn1_w_tensor
);
}
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto
*
combined_w_desc
=
matmul0_w
->
Var
();
combined_w_desc
->
SetShape
({
3
,
num_head
,
dim_head
,
dim_embed
});
combined_w_desc
->
SetPersistable
(
true
);
auto
*
combined_bias_desc
=
eltadd0_b
->
Var
();
combined_bias_desc
->
SetShape
({
3
,
num_head
,
dim_head
});
combined_bias_desc
->
SetPersistable
(
true
);
scope
->
EraseVars
({
matmul1_w
->
Name
(),
matmul2_w
->
Name
()});
scope
->
EraseVars
({
eltadd1_b
->
Name
(),
eltadd2_b
->
Name
()});
// create fused_multi_transformer
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
fused_multi_transformer_op_desc
.
SetType
(
enable_int8
?
"fused_multi_transformer_int8"
:
"fused_multi_transformer"
);
// 1. Input setting
fused_multi_transformer_op_desc
.
SetInput
(
"X"
,
{
input0
->
Name
()});
// pre-LayerNorm input
fused_multi_transformer_op_desc
.
SetInput
(
"LnScale"
,
{
layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"LnBias"
,
{
layer_norm_bias
->
Name
()});
// QKV computation input
fused_multi_transformer_op_desc
.
SetInput
(
"QKVW"
,
{
matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"QKVBias"
,
{
eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"SrcMask"
,
{
eltadd_qk_b
->
Name
()});
// CacheKV input
VarDesc
cache_kv_desc
(
"cache_kv"
+
std
::
to_string
(
layer_idx
));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
bq_tensor
->
dtype
()));
cache_kv_desc
.
SetPersistable
(
false
);
auto
*
cache_kv
=
graph
->
CreateVarNode
(
&
cache_kv_desc
);
OpDesc
fill_const_op_desc
(
layer_norm
->
Op
()
->
Block
());
fill_const_op_desc
.
SetType
(
"fill_constant_batch_size_like"
);
fill_const_op_desc
.
SetInput
(
"Input"
,
{
input0
->
Name
()});
fill_const_op_desc
.
SetOutput
(
"Out"
,
{
cache_kv
->
Name
()});
std
::
vector
<
int
>
shape
=
{
2
,
-
1
,
num_head
,
1024
,
dim_head
};
fill_const_op_desc
.
SetAttr
(
"shape"
,
shape
);
fill_const_op_desc
.
SetAttr
(
"input_dim_idx"
,
0
);
fill_const_op_desc
.
SetAttr
(
"output_dim_idx"
,
1
);
fill_const_op_desc
.
SetAttr
(
"value"
,
0
);
fill_const_op_desc
.
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
framework
::
TransToProtoVarType
(
bq_tensor
->
dtype
())));
auto
*
fill_const_op
=
graph
->
CreateOpNode
(
&
fill_const_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv
->
Name
()});
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
{
matmul_linear_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearBias"
,
{
eltadd_linear_b
->
Name
()});
// Feed Forward input
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnScale"
,
{
ffn_layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnBias"
,
{
ffn_layer_norm_bias
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Weight"
,
{
ffn_matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Bias"
,
{
ffn_eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Weight"
,
{
ffn_matmul1_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Bias"
,
{
ffn_eltadd1_b
->
Name
()});
// 2. Output setting
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_layer_norm_out
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv
->
Name
()});
// Attribute setting
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
false
);
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
"act_method"
,
{
ffn_act
->
Op
()
->
Type
()});
// Quantization attribute/Input
if
(
enable_int8
)
{
auto
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()));
// Set input scale
std
::
string
qkv_input_name
=
matmul0_op
->
Input
(
"X"
)[
0
];
auto
qkv_in_scale
=
PADDLE_GET_CONST
(
float
,
matmul0_op
->
GetAttr
(
"Input_scale_"
+
qkv_input_name
));
std
::
string
out_linear_input_name
=
matmul_linear_op
->
Input
(
"X"
)[
0
];
auto
out_linear_in_scale
=
PADDLE_GET_CONST
(
float
,
matmul_linear_op
->
GetAttr
(
"Input_scale_"
+
out_linear_input_name
));
std
::
string
ffn0_input_name
=
ffn_matmul_0_op
->
Input
(
"X"
)[
0
];
auto
ffn0_in_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_0_op
->
GetAttr
(
"Input_scale_"
+
ffn0_input_name
));
std
::
string
ffn1_input_name
=
ffn_matmul_1_op
->
Input
(
"X"
)[
0
];
auto
ffn1_in_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_1_op
->
GetAttr
(
"Input_scale_"
+
ffn1_input_name
));
// Calc outscale and Set them
auto
qkv_weight_scale
=
PADDLE_GET_CONST
(
float
,
matmul0_op
->
GetAttr
(
"weight_scale"
));
auto
out_weight_scale
=
PADDLE_GET_CONST
(
float
,
matmul_linear_op
->
GetAttr
(
"weight_scale"
));
auto
ffn0_weight_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_0_op
->
GetAttr
(
"weight_scale"
));
auto
ffn1_weight_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_1_op
->
GetAttr
(
"weight_scale"
));
auto
qkv_out_scales
=
std
::
vector
<
float
>
(
3
*
dim_embed
,
(
qkv_weight_scale
/
127.0
f
)
*
(
qkv_in_scale
/
127.0
f
));
auto
out_out_scales
=
std
::
vector
<
float
>
(
dim_embed
,
(
out_weight_scale
/
127.0
f
)
*
(
out_linear_in_scale
/
127.0
f
));
auto
ffn0_out_scales
=
std
::
vector
<
float
>
(
4
*
dim_embed
,
(
ffn0_weight_scale
/
127.0
f
)
*
(
ffn0_in_scale
/
127.0
f
));
auto
ffn1_out_scales
=
std
::
vector
<
float
>
(
dim_embed
,
(
ffn1_weight_scale
/
127.0
f
)
*
(
ffn1_in_scale
/
127.0
f
));
// Inverse input scale
qkv_in_scale
=
1.0
f
/
qkv_in_scale
;
out_linear_in_scale
=
1.0
f
/
out_linear_in_scale
;
ffn0_in_scale
=
1.0
f
/
ffn0_in_scale
;
ffn1_in_scale
=
1.0
f
/
ffn1_in_scale
;
fused_multi_transformer_op_desc
.
SetAttr
(
"qkv_in_scale"
,
std
::
vector
<
float
>
{
qkv_in_scale
});
fused_multi_transformer_op_desc
.
SetAttr
(
"out_linear_in_scale"
,
std
::
vector
<
float
>
{
out_linear_in_scale
});
fused_multi_transformer_op_desc
.
SetAttr
(
"ffn1_in_scale"
,
std
::
vector
<
float
>
{
ffn0_in_scale
});
fused_multi_transformer_op_desc
.
SetAttr
(
"ffn2_in_scale"
,
std
::
vector
<
float
>
{
ffn1_in_scale
});
auto
qkv_out_scale_var
=
scope
->
Var
(
matmul0_w
->
Name
()
+
"_out_scale"
);
auto
out_out_scale_var
=
scope
->
Var
(
matmul_linear_w
->
Name
()
+
"_out_scale"
);
auto
ffn0_out_scale_var
=
scope
->
Var
(
ffn_matmul0_w
->
Name
()
+
"_out_scale"
);
auto
ffn1_out_scale_var
=
scope
->
Var
(
ffn_matmul1_w
->
Name
()
+
"_out_scale"
);
auto
*
qkv_out_scale_tensor
=
qkv_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
qkv_out_scale_tensor
->
Resize
({
3
*
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
qkv_out_scale_tensor
);
auto
qkv_out_scale_data
=
qkv_out_scale_tensor
->
data
<
float
>
();
memcpy
(
qkv_out_scale_data
,
qkv_out_scales
.
data
(),
qkv_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
"QKVOutScale"
,
{
matmul0_w
->
Name
()
+
"_out_scale"
});
auto
*
out_out_scale_tensor
=
out_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
out_out_scale_tensor
->
Resize
({
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
out_out_scale_tensor
);
auto
out_out_scale_data
=
out_out_scale_tensor
->
data
<
float
>
();
memcpy
(
out_out_scale_data
,
out_out_scales
.
data
(),
out_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearOutScale"
,
{
matmul_linear_w
->
Name
()
+
"_out_scale"
});
auto
*
ffn0_out_scale_tensor
=
ffn0_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
ffn0_out_scale_tensor
->
Resize
({
4
*
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
ffn0_out_scale_tensor
);
auto
ffn0_out_scale_data
=
ffn0_out_scale_tensor
->
data
<
float
>
();
memcpy
(
ffn0_out_scale_data
,
ffn0_out_scales
.
data
(),
ffn0_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1OutScale"
,
{
ffn_matmul0_w
->
Name
()
+
"_out_scale"
});
auto
*
ffn1_out_scale_tensor
=
ffn1_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
ffn1_out_scale_tensor
->
Resize
({
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
ffn1_out_scale_tensor
);
auto
ffn1_out_scale_data
=
ffn1_out_scale_tensor
->
data
<
float
>
();
memcpy
(
ffn1_out_scale_data
,
ffn1_out_scales
.
data
(),
ffn1_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2OutScale"
,
{
ffn_matmul1_w
->
Name
()
+
"_out_scale"
});
}
auto
*
fused_multi_transformer
=
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
if
(
enable_int8
)
{
auto
qkv_out_scale_node
=
CreatePersistableVarNode
(
graph
,
matmul0_w
->
Name
()
+
"_out_scale"
);
auto
out_out_scale_node
=
CreatePersistableVarNode
(
graph
,
matmul_linear_w
->
Name
()
+
"_out_scale"
);
auto
ffn0_out_scale_node
=
CreatePersistableVarNode
(
graph
,
ffn_matmul0_w
->
Name
()
+
"_out_scale"
);
auto
ffn1_out_scale_node
=
CreatePersistableVarNode
(
graph
,
ffn_matmul1_w
->
Name
()
+
"_out_scale"
);
IR_NODE_LINK_TO
(
qkv_out_scale_node
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
out_out_scale_node
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn0_out_scale_node
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn1_out_scale_node
,
fused_multi_transformer
);
}
IR_NODE_LINK_TO
(
input0
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
input0
,
fill_const_op
);
IR_NODE_LINK_TO
(
fill_const_op
,
cache_kv
);
IR_NODE_LINK_TO
(
cache_kv
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul_linear_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_linear_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_matmul1_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_eltadd1_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
fused_multi_transformer
,
ffn_layer_norm_out
);
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"fused_multi_transformer_encoder pass in "
"op compat failed."
;
return
;
}
VLOG
(
4
)
<<
"handle MultiTransformer encoder fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
input0
,
input0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0
,
matmul0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_out
,
matmul0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_w
,
matmul0_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1
,
matmul1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1_out
,
matmul1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1_w
,
matmul1_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1
,
reshape2_1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1_out
,
reshape2_1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1
,
transpose2_1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1_out
,
transpose2_1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_q
,
scale_q
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_q_out
,
scale_q_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2
,
matmul2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_out
,
matmul2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_w
,
matmul2_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2
,
reshape2_2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2_out
,
reshape2_2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2
,
transpose2_2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2_out
,
transpose2_2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
attention_output
,
attention_output
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
ffn_layer_norm
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
ffn_layer_norm_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
ffn_matmul0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
,
ffn_eltadd0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_b
,
ffn_eltadd0_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_act
,
ffn_act
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_act_out
,
ffn_act_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
ffn_matmul1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1
,
ffn_eltadd1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_b
,
ffn_eltadd1_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_output
,
ffn_output
,
fused_multi_transformer_pattern
)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_b
,
eltadd0_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1
,
eltadd1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1_b
,
eltadd1_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1_out
,
eltadd1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2
,
eltadd2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2_b
,
eltadd2_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2_out
,
eltadd2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk
,
matmul_qk
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear
,
matmul_linear
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
matmul_linear_w
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
matmul_linear_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fused_multi_transformer_pattern
)
fuse_creater
(
input0
,
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
matmul0
,
matmul0_w
,
matmul1_w
,
matmul2_w
,
eltadd0_b
,
eltadd1_b
,
eltadd2_b
,
transpose2_1_out
,
transpose2_2_out
,
eltadd_qk_b
,
reshape2_0
,
matmul_linear
,
matmul_linear_w
,
eltadd_linear_b
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_matmul0
,
ffn_matmul0_w
,
ffn_matmul1
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_act
,
ffn_layer_norm_out
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
layer_norm_mean
,
layer_norm_variance
,
layer_norm_out
,
matmul0
,
matmul1
,
matmul2
,
matmul0_out
,
matmul1_out
,
matmul2_out
,
eltadd0
,
eltadd1
,
eltadd2
,
eltadd0_out
,
eltadd1_out
,
eltadd2_out
,
reshape2_0
,
reshape2_1
,
reshape2_2
,
reshape2_0_out
,
reshape2_1_out
,
reshape2_2_out
,
transpose2_0
,
transpose2_1
,
transpose2_2
,
transpose2_0_out
,
transpose2_1_out
,
transpose2_2_out
,
scale_q
,
scale_q_out
,
matmul_qk
,
matmul_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
softmax_qk_out
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_qkv
,
matmul_qkv_out
,
reshape2_qkv
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_linear
,
matmul_linear_out
,
eltadd_linear
,
eltadd_linear_out
,
eltadd_out
,
ffn_layer_norm
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_matmul0
,
ffn_matmul1
,
ffn_matmul0_out
,
ffn_matmul1_out
,
ffn_eltadd0
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_act
,
ffn_act_out
,
ffn_output
,
ffn_eltadd_out
});
inline
void
QKVWeightsBiasProcess
(
phi
::
DenseTensor
*
wq_tensor
,
// Remove unneeded nodes.
phi
::
DenseTensor
*
wk_tensor
,
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
phi
::
DenseTensor
*
wv_tensor
,
++
fusion_count
;
phi
::
DenseTensor
*
bq_tensor
,
};
phi
::
DenseTensor
*
bk_tensor
,
gpd
(
graph
,
handler
);
phi
::
DenseTensor
*
bv_tensor
,
const
int
num_head
,
const
int
dim_head
,
const
int
dim_embed
)
{
switch
(
wq_tensor
->
dtype
())
{
case
paddle
::
experimental
::
DataType
::
FLOAT16
:
QKVWeightsProcess
<
platform
::
float16
>
(
wq_tensor
,
wk_tensor
,
wv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
case
paddle
::
experimental
::
DataType
::
FLOAT32
:
QKVWeightsProcess
<
float
>
(
wq_tensor
,
wk_tensor
,
wv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
case
paddle
::
experimental
::
DataType
::
INT8
:
QKVWeightsProcess
<
int8_t
>
(
wq_tensor
,
wk_tensor
,
wv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32/fp16/int8."
));
break
;
}
switch
(
bq_tensor
->
dtype
())
{
case
paddle
::
experimental
::
DataType
::
FLOAT16
:
QKVBiasProcess
<
platform
::
float16
>
(
bq_tensor
,
bk_tensor
,
bv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
case
paddle
::
experimental
::
DataType
::
FLOAT32
:
QKVBiasProcess
<
float
>
(
bq_tensor
,
bk_tensor
,
bv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"fused_multi_transformer not supported bias dtype. "
"we now only support fp32/fp16."
));
break
;
}
}
template
<
typename
T
>
return
fusion_count
;
inline
void
QKVWeightsProcessFuseQKV
(
phi
::
DenseTensor
*
qkv_w_tensor
,
}
const
int
num_head
,
const
int
dim_head
,
const
int
dim_embed
)
{
auto
*
qkv_w_data
=
qkv_w_tensor
->
data
<
T
>
();
auto
transpose_w_dims
=
phi
::
make_ddim
({
3
,
num_head
,
dim_head
,
dim_embed
});
phi
::
DenseTensor
tmp_transpose_w_tensor
;
void
FusedMultiTransformerEncoderPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
tmp_transpose_w_tensor
.
Resize
(
transpose_w_dims
);
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
tmp_transpose_w_data
=
auto
*
scope
=
param_scope
();
tmp_transpose_w_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
Fatal
(
"During the multi_transformer pass, The scope should not be null."
));
// transpose qkv matmul Y to QKVWeights
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
if
(
fusion_count
>
0
)
{
for
(
int
j
=
0
;
j
<
num_head
;
j
++
)
{
graph
->
Set
(
kFusedMultiTransformerEncoderPass
,
new
bool
(
true
));
for
(
int
k
=
0
;
k
<
dim_head
;
k
++
)
{
graph
->
Set
(
kFusedMultiTransformerEncoderFusionCount
,
new
int
(
fusion_count
));
for
(
int
l
=
0
;
l
<
dim_embed
;
l
++
)
{
int
out_idx
=
i
*
num_head
*
dim_head
*
dim_embed
+
j
*
dim_head
*
dim_embed
+
k
*
dim_embed
+
l
;
int
in_idx
=
l
*
num_head
*
3
*
dim_head
+
j
*
3
*
dim_head
+
i
*
dim_head
+
k
;
tmp_transpose_w_data
[
out_idx
]
=
qkv_w_data
[
in_idx
];
}
}
}
}
}
AddStatis
(
fusion_count
);
qkv_w_tensor
->
Resize
(
transpose_w_dims
);
auto
*
new_transpose_w_data
=
qkv_w_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
memcpy
(
new_transpose_w_data
,
tmp_transpose_w_data
,
sizeof
(
T
)
*
qkv_w_tensor
->
numel
());
}
}
template
<
typename
T
>
FusedMultiTransformerEncoderPass
::
FusedMultiTransformerEncoderPass
()
{
inline
void
QKVBiasProcessFuseQKV
(
phi
::
DenseTensor
*
qkv_b_tensor
,
AddOpCompat
(
OpCompat
(
"layer_norm"
))
const
int
num_head
,
.
AddInput
(
"X"
)
const
int
dim_head
,
.
IsTensor
()
const
int
dim_embed
)
{
.
End
()
auto
*
qkv_b_data
=
qkv_b_tensor
->
data
<
T
>
();
.
AddInput
(
"Scale"
)
auto
transpose_b_dims
=
phi
::
make_ddim
({
3
,
num_head
,
dim_head
});
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Mean"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Variance"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"epsilon"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
0.001
f
)
.
End
()
.
AddAttr
(
"begin_norm_axis"
)
.
IsNumGT
(
0
)
.
End
();
phi
::
DenseTensor
tmp_transpose_b_tensor
;
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
tmp_transpose_b_tensor
.
Resize
(
transpose_b_dims
);
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
auto
*
tmp_transpose_b_data
=
.
IsTensor
()
tmp_transpose_b_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
.
End
()
.
AddInput
(
"Y"
)
// the shape shoule be (N*H, N*H)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsType
<
bool
>
()
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsType
<
bool
>
()
.
End
();
// transpose qkv elemenwise_add Y to QKVBias
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
.
AddInput
(
"X"
)
for
(
int
j
=
0
;
j
<
num_head
;
j
++
)
{
.
IsTensor
()
for
(
int
k
=
0
;
k
<
dim_head
;
k
++
)
{
.
End
()
int
out_idx
=
i
*
num_head
*
dim_head
+
j
*
dim_head
+
k
;
.
AddInput
(
"Y"
)
int
in_idx
=
j
*
3
*
dim_head
+
i
*
dim_head
+
k
;
.
IsTensor
()
tmp_transpose_b_data
[
out_idx
]
=
qkv_b_data
[
in_idx
];
.
End
()
}
.
AddOutput
(
"Out"
)
}
.
IsTensor
()
}
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
2
,
-
1
,
0
})
.
End
();
AddOpCompat
(
OpCompat
(
"reshape2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Shape"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ShapeTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"shape"
)
// -->(B, S, H, N) <--(B, S, N*H)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"transpose2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
// {0, 2, 1, 3}
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
qkv_b_tensor
->
Resize
({
3
,
num_head
,
dim_head
});
AddOpCompat
(
OpCompat
(
"scale"
))
auto
*
new_transpose_b_data
=
.
AddInput
(
"X"
)
qkv_b_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
.
IsTensor
()
memcpy
(
new_transpose_b_data
,
.
End
()
tmp_transpose_b_data
,
.
AddOutput
(
"Out"
)
sizeof
(
T
)
*
qkv_b_tensor
->
numel
());
.
IsTensor
()
}
.
End
()
.
AddAttr
(
"scale"
)
.
IsType
<
float
>
()
// copy to new op. so unconstrained.
.
End
()
.
AddAttr
(
"bias"
)
.
IsNumEQ
(
0.
f
)
.
End
()
.
AddAttr
(
"bias_after_scale"
)
// bias is 0, so unconstrained.
.
IsType
<
bool
>
()
.
End
();
inline
void
QKVWeightsBiasProcessFuseQKV
(
phi
::
DenseTensor
*
qkv_w_tensor
,
AddOpCompat
(
OpCompat
(
"softmax"
))
phi
::
DenseTensor
*
qkv_b_tensor
,
.
AddInput
(
"X"
)
const
int
num_head
,
.
IsTensor
()
const
int
dim_head
,
.
End
()
const
int
dim_embed
)
{
.
AddOutput
(
"Out"
)
switch
(
qkv_w_tensor
->
dtype
())
{
.
IsTensor
()
case
paddle
::
experimental
::
DataType
::
FLOAT16
:
.
End
()
QKVWeightsProcessFuseQKV
<
platform
::
float16
>
(
.
AddAttr
(
"axis"
)
qkv_w_tensor
,
num_head
,
dim_head
,
dim_embed
);
.
IsIntIn
({
-
1
,
3
})
// shape is (B, H, S, S), so axis is -1 or 3
break
;
.
End
();
case
paddle
::
experimental
::
DataType
::
FLOAT32
:
QKVWeightsProcessFuseQKV
<
float
>
(
qkv_w_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
case
paddle
::
experimental
::
DataType
::
INT8
:
QKVWeightsProcessFuseQKV
<
int8_t
>
(
qkv_w_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32/fp16/int8."
));
break
;
}
switch
(
qkv_b_tensor
->
dtype
())
{
case
paddle
::
experimental
::
DataType
::
FLOAT16
:
QKVBiasProcessFuseQKV
<
platform
::
float16
>
(
qkv_b_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
case
paddle
::
experimental
::
DataType
::
FLOAT32
:
QKVBiasProcessFuseQKV
<
float
>
(
qkv_b_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"fused_multi_transformer not supported bias dtype. "
"we now only support fp32/fp16."
));
break
;
}
}
// Just use for fused_multi_transformer_int8
AddOpCompat
(
OpCompat
(
"gelu"
))
inline
void
TransposeWeights
(
phi
::
DenseTensor
*
weight_tensor
)
{
.
AddInput
(
"X"
)
int
m
=
weight_tensor
->
dims
()[
0
];
.
IsTensor
()
int
n
=
weight_tensor
->
dims
()[
1
];
.
End
()
phi
::
DenseTensor
tmp_weight_tensor
;
.
AddOutput
(
"Out"
)
auto
tmp_weight_data
=
.
IsTensor
()
tmp_weight_tensor
.
mutable_data
<
int8_t
>
({
n
,
m
},
platform
::
CPUPlace
());
.
End
()
auto
weight_data
=
weight_tensor
->
data
<
int8_t
>
();
.
AddAttr
(
"approximate"
)
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
.
IsType
<
bool
>
()
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
.
End
();
int
in_idx
=
i
*
n
+
j
;
int
out_idx
=
j
*
m
+
i
;
tmp_weight_data
[
out_idx
]
=
weight_data
[
in_idx
];
}
}
weight_tensor
->
Resize
({
n
,
m
});
auto
new_weight_data
=
weight_tensor
->
mutable_data
<
int8_t
>
(
platform
::
CPUPlace
());
memcpy
(
new_weight_data
,
tmp_weight_data
,
sizeof
(
int8_t
)
*
m
*
n
);
}
inline
Node
*
CreatePersistableVarNode
(
Graph
*
graph
,
const
std
::
string
&
name
)
{
AddOpCompat
(
OpCompat
(
"relu"
))
auto
var_desc
=
VarDesc
(
name
);
.
AddInput
(
"X"
)
var_desc
.
SetDataType
(
framework
::
proto
::
VarType
::
FP32
);
.
IsTensor
()
var_desc
.
SetPersistable
(
true
);
.
End
()
auto
node
=
graph
->
CreateVarNode
(
&
var_desc
);
.
AddOutput
(
"Out"
)
return
node
;
.
IsTensor
()
.
End
();
}
}
int
FusedMultiTransformerEncoderPass
::
BuildFusion
(
Graph
*
graph
,
int
FusedMultiTransformerEncoderFuseQKVPass
::
BuildFusion
(
const
std
::
string
&
name_scope
,
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
Scope
*
scope
)
const
{
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
auto
*
pattern
=
gpd
.
mutable_pattern
();
bool
enable_int8
=
graph
->
Get
<
bool
>
(
"enable_int8"
);
bool
enable_int8
=
graph
->
Get
<
bool
>
(
"enable_int8"
);
if
(
enable_int8
)
{
if
(
enable_int8
)
{
VLOG
(
3
)
<<
"FusedMultiTransformerEncoderPass with int8"
;
VLOG
(
3
)
<<
"FusedMultiTransformerEncoder
FuseQKV
Pass with int8"
;
}
else
{
}
else
{
VLOG
(
3
)
<<
"FusedMultiTransformerEncoderPass with fp"
;
VLOG
(
3
)
<<
"FusedMultiTransformerEncoder
FuseQKV
Pass with fp"
;
}
}
// Create pattern.
// Create pattern.
patterns
::
FusedMultiTransformerEncoder
Pattern
fused_multi_transformer_pattern
(
patterns
::
FusedMultiTransformerEncoder
FuseQKVPattern
pattern
,
name_scope
);
fused_multi_transformer_fuse_qkv_pattern
(
pattern
,
name_scope
);
fused_multi_transformer_pattern
();
fused_multi_transformer_
fuse_qkv_
pattern
();
// Create New OpDesc
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
auto
fuse_creater
=
[
&
](
Node
*
input0
,
...
@@ -1316,13 +2578,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1316,13 +2578,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node
*
layer_norm_variance
,
Node
*
layer_norm_variance
,
Node
*
matmul0
,
Node
*
matmul0
,
Node
*
matmul0_w
,
Node
*
matmul0_w
,
Node
*
matmul1_w
,
Node
*
matmul2_w
,
Node
*
eltadd0_b
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
split0_k_out
,
Node
*
eltadd2_b
,
Node
*
split0_v_out
,
Node
*
transpose2_1_out
,
Node
*
transpose2_2_out
,
Node
*
eltadd_qk_b
,
Node
*
eltadd_qk_b
,
Node
*
reshape2_0
,
Node
*
reshape2_0
,
Node
*
matmul_linear
,
Node
*
matmul_linear
,
...
@@ -1340,6 +2598,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1340,6 +2598,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node
*
ffn_matmul1_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_act
,
Node
*
ffn_output
)
{
Node
*
ffn_output
)
{
auto
*
matmul0_op
=
matmul0
->
Op
();
auto
*
matmul0_op
=
matmul0
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
...
@@ -1355,43 +2614,28 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1355,43 +2614,28 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
auto
*
wq
_tensor
=
auto
*
qkv_w
_tensor
=
scope
->
FindVar
(
matmul0_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
scope
->
FindVar
(
matmul0_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
wk_tensor
=
auto
*
qkv_b_tensor
=
scope
->
FindVar
(
matmul1_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
wv_tensor
=
scope
->
FindVar
(
matmul2_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
bq_tensor
=
scope
->
FindVar
(
eltadd0_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
scope
->
FindVar
(
eltadd0_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
bk_tensor
=
scope
->
FindVar
(
eltadd1_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
bv_tensor
=
scope
->
FindVar
(
eltadd2_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
// NOTE(minghaoBD): to make it compatible with strucutured pruning on
// NOTE(minghaoBD): to make it compatible with strucutured pruning on
// num_head dimension:
// num_head dimension:
// 1. get dim_head from reshape.shape[3], dim_embed from
// 1. get dim_head from reshape.shape[3], dim_embed from
// layer_norm_bias.shape[0]
// layer_norm_bias.shape[0]
// 2. calculate num_head according to wq
_tensor.shape[1]
and dim_head
// 2. calculate num_head according to wq
kv_tensor.shape[1]/3
and dim_head
auto
reshape_desc
=
reshape2_0
->
Op
();
auto
reshape_desc
=
reshape2_0
->
Op
();
int
dim_head
=
int
dim_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
3
);
.
at
(
3
)
/
3
;
// 3 for qkv
auto
*
layer_norm_bias_tensor
=
auto
*
layer_norm_bias_tensor
=
scope
->
FindVar
(
layer_norm_bias
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
scope
->
FindVar
(
layer_norm_bias
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
int
dim_embed
=
layer_norm_bias_tensor
->
dims
()[
0
];
int
dim_embed
=
layer_norm_bias_tensor
->
dims
()[
0
];
int
num_head
=
wq_tensor
->
dims
()[
1
]
/
dim_head
;
int
num_head
=
qkv_w_tensor
->
dims
()[
1
]
/
3
/
dim_head
;
QKVWeightsBiasProcess
(
wq_tensor
,
QKVWeightsBiasProcessFuseQKV
(
wk_tensor
,
qkv_w_tensor
,
qkv_b_tensor
,
num_head
,
dim_head
,
dim_embed
);
wv_tensor
,
bq_tensor
,
bk_tensor
,
bv_tensor
,
num_head
,
dim_head
,
dim_embed
);
if
(
enable_int8
)
{
if
(
enable_int8
)
{
auto
*
out_linear_w_tensor
=
scope
->
FindVar
(
matmul_linear_w
->
Name
())
auto
*
out_linear_w_tensor
=
scope
->
FindVar
(
matmul_linear_w
->
Name
())
...
@@ -1406,18 +2650,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1406,18 +2650,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
TransposeWeights
(
ffn1_w_tensor
);
TransposeWeights
(
ffn1_w_tensor
);
}
}
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto
*
combined_w_desc
=
matmul0_w
->
Var
();
combined_w_desc
->
SetShape
({
3
,
num_head
,
dim_head
,
dim_embed
});
combined_w_desc
->
SetPersistable
(
true
);
auto
*
combined_bias_desc
=
eltadd0_b
->
Var
();
combined_bias_desc
->
SetShape
({
3
,
num_head
,
dim_head
});
combined_bias_desc
->
SetPersistable
(
true
);
scope
->
EraseVars
({
matmul1_w
->
Name
(),
matmul2_w
->
Name
()});
scope
->
EraseVars
({
eltadd1_b
->
Name
(),
eltadd2_b
->
Name
()});
// create fused_multi_transformer
// create fused_multi_transformer
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
fused_multi_transformer_op_desc
.
SetType
(
enable_int8
fused_multi_transformer_op_desc
.
SetType
(
enable_int8
...
@@ -1442,7 +2674,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1442,7 +2674,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
VarDesc
cache_kv_desc
(
"cache_kv"
+
std
::
to_string
(
layer_idx
));
VarDesc
cache_kv_desc
(
"cache_kv"
+
std
::
to_string
(
layer_idx
));
// FIXME: only support max_seq_len <= 1024
// FIXME: only support max_seq_len <= 1024
cache_kv_desc
.
SetDataType
(
cache_kv_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
bq
_tensor
->
dtype
()));
framework
::
TransToProtoVarType
(
qkv_b
_tensor
->
dtype
()));
cache_kv_desc
.
SetPersistable
(
false
);
cache_kv_desc
.
SetPersistable
(
false
);
auto
*
cache_kv
=
graph
->
CreateVarNode
(
&
cache_kv_desc
);
auto
*
cache_kv
=
graph
->
CreateVarNode
(
&
cache_kv_desc
);
...
@@ -1455,9 +2687,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1455,9 +2687,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fill_const_op_desc
.
SetAttr
(
"input_dim_idx"
,
0
);
fill_const_op_desc
.
SetAttr
(
"input_dim_idx"
,
0
);
fill_const_op_desc
.
SetAttr
(
"output_dim_idx"
,
1
);
fill_const_op_desc
.
SetAttr
(
"output_dim_idx"
,
1
);
fill_const_op_desc
.
SetAttr
(
"value"
,
0
);
fill_const_op_desc
.
SetAttr
(
"value"
,
0
);
fill_const_op_desc
.
SetAttr
(
fill_const_op_desc
.
SetAttr
(
"dtype"
,
"dtype"
,
static_cast
<
int
>
(
framework
::
TransToProtoVarType
(
static_cast
<
int
>
(
framework
::
TransToProtoVarType
(
bq
_tensor
->
dtype
())));
qkv_b
_tensor
->
dtype
())));
auto
*
fill_const_op
=
graph
->
CreateOpNode
(
&
fill_const_op_desc
);
auto
*
fill_const_op
=
graph
->
CreateOpNode
(
&
fill_const_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv
->
Name
()});
...
@@ -1490,12 +2722,17 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1490,12 +2722,17 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"act_method"
,
ffn_act
->
Op
()
->
Type
());
// output dropout attribute
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
// Quantization attribute/Input
// Quantization attribute/Input
if
(
enable_int8
)
{
if
(
enable_int8
)
{
auto
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()));
// Set input scale
// Set input scale
std
::
string
qkv_input_name
=
matmul0_op
->
Input
(
"X"
)[
0
];
std
::
string
qkv_input_name
=
matmul0_op
->
Input
(
"X"
)[
0
];
auto
qkv_in_scale
=
PADDLE_GET_CONST
(
auto
qkv_in_scale
=
PADDLE_GET_CONST
(
...
@@ -1512,6 +2749,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1512,6 +2749,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
float
,
ffn_matmul_1_op
->
GetAttr
(
"Input_scale_"
+
ffn1_input_name
));
float
,
ffn_matmul_1_op
->
GetAttr
(
"Input_scale_"
+
ffn1_input_name
));
// Calc outscale and Set them
// Calc outscale and Set them
// TODO(wufeisheng): Currently just match layer-wise weight scale, where
// channel-wise weight scale should also be surpported.
auto
qkv_weight_scale
=
auto
qkv_weight_scale
=
PADDLE_GET_CONST
(
float
,
matmul0_op
->
GetAttr
(
"weight_scale"
));
PADDLE_GET_CONST
(
float
,
matmul0_op
->
GetAttr
(
"weight_scale"
));
auto
out_weight_scale
=
auto
out_weight_scale
=
...
@@ -1555,36 +2794,44 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1555,36 +2794,44 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
auto
ffn1_out_scale_var
=
auto
ffn1_out_scale_var
=
scope
->
Var
(
ffn_matmul1_w
->
Name
()
+
"_out_scale"
);
scope
->
Var
(
ffn_matmul1_w
->
Name
()
+
"_out_scale"
);
auto
qkv_out_scale_data
=
auto
*
qkv_out_scale_tensor
=
qkv_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
qkv_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
->
mutable_data
<
float
>
({
3
*
dim_embed
},
platform
::
CPUPlace
());
qkv_out_scale_tensor
->
Resize
({
3
*
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
qkv_out_scale_tensor
);
auto
qkv_out_scale_data
=
qkv_out_scale_tensor
->
data
<
float
>
();
memcpy
(
qkv_out_scale_data
,
memcpy
(
qkv_out_scale_data
,
qkv_out_scales
.
data
(),
qkv_out_scales
.
data
(),
qkv_out_scales
.
size
()
*
sizeof
(
float
));
qkv_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
fused_multi_transformer_op_desc
.
SetInput
(
"QKVOutScale"
,
{
matmul0_w
->
Name
()
+
"_out_scale"
});
"QKVOutScale"
,
{
matmul0_w
->
Name
()
+
"_out_scale"
});
auto
out_out_scale_data
=
auto
*
out_out_scale_tensor
=
out_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
out_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
->
mutable_data
<
float
>
({
dim_embed
},
platform
::
CPUPlace
());
out_out_scale_tensor
->
Resize
({
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
out_out_scale_tensor
);
auto
out_out_scale_data
=
out_out_scale_tensor
->
data
<
float
>
();
memcpy
(
out_out_scale_data
,
memcpy
(
out_out_scale_data
,
out_out_scales
.
data
(),
out_out_scales
.
data
(),
out_out_scales
.
size
()
*
sizeof
(
float
));
out_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearOutScale"
,
{
matmul_linear_w
->
Name
()
+
"_out_scale"
});
"OutLinearOutScale"
,
{
matmul_linear_w
->
Name
()
+
"_out_scale"
});
auto
ffn0_out_scale_data
=
auto
*
ffn0_out_scale_tensor
=
ffn0_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
ffn0_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
->
mutable_data
<
float
>
({
4
*
dim_embed
},
platform
::
CPUPlace
());
ffn0_out_scale_tensor
->
Resize
({
4
*
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
ffn0_out_scale_tensor
);
auto
ffn0_out_scale_data
=
ffn0_out_scale_tensor
->
data
<
float
>
();
memcpy
(
ffn0_out_scale_data
,
memcpy
(
ffn0_out_scale_data
,
ffn0_out_scales
.
data
(),
ffn0_out_scales
.
data
(),
ffn0_out_scales
.
size
()
*
sizeof
(
float
));
ffn0_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1OutScale"
,
{
ffn_matmul0_w
->
Name
()
+
"_out_scale"
});
"FFN1OutScale"
,
{
ffn_matmul0_w
->
Name
()
+
"_out_scale"
});
auto
ffn1_out_scale_data
=
auto
*
ffn1_out_scale_tensor
=
ffn1_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
ffn1_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
->
mutable_data
<
float
>
({
dim_embed
},
platform
::
CPUPlace
());
ffn1_out_scale_tensor
->
Resize
({
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
ffn1_out_scale_tensor
);
auto
ffn1_out_scale_data
=
ffn1_out_scale_tensor
->
data
<
float
>
();
memcpy
(
ffn1_out_scale_data
,
memcpy
(
ffn1_out_scale_data
,
ffn1_out_scales
.
data
(),
ffn1_out_scales
.
data
(),
ffn1_out_scales
.
size
()
*
sizeof
(
float
));
ffn1_out_scales
.
size
()
*
sizeof
(
float
));
...
@@ -1641,27 +2888,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1641,27 +2888,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
auto
while_Xs
=
while0
->
Op
()
->
Input
(
"X"
);
auto
while_Xs
=
while0
->
Op
()
->
Input
(
"X"
);
while_Xs
.
erase
(
while_Xs
.
erase
(
std
::
remove
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
transpose2_1_out
->
Name
()),
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
split0_k_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
transpose2_2_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
matmul1_w
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
matmul2_w
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
eltadd1_b
->
Name
()),
std
::
end
(
while_Xs
));
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
while_Xs
.
erase
(
std
::
remove
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
eltadd2_b
->
Name
()),
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
split0_v_out
->
Name
()),
std
::
end
(
while_Xs
));
std
::
end
(
while_Xs
));
while_Xs
.
emplace_back
(
cache_kv
->
Name
());
while_Xs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetInput
(
"X"
,
while_Xs
);
while0
->
Op
()
->
SetInput
(
"X"
,
while_Xs
);
...
@@ -1670,13 +2901,13 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1670,13 +2901,13 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
// 1. delete k, v
// 1. delete k, v
// 2. add cache_kv
// 2. add cache_kv
auto
while_Outs
=
while0
->
Op
()
->
Output
(
"Out"
);
auto
while_Outs
=
while0
->
Op
()
->
Output
(
"Out"
);
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
while_Outs
.
erase
(
std
::
end
(
while_Outs
),
std
::
remove
(
transpose2_1
_out
->
Name
()),
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
split0_k
_out
->
Name
()),
std
::
end
(
while_Outs
));
std
::
end
(
while_Outs
));
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
while_Outs
.
erase
(
std
::
end
(
while_Outs
),
std
::
remove
(
transpose2_2
_out
->
Name
()),
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
split0_v
_out
->
Name
()),
std
::
end
(
while_Outs
));
std
::
end
(
while_Outs
));
while_Outs
.
emplace_back
(
cache_kv
->
Name
());
while_Outs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetOutput
(
"Out"
,
while_Outs
);
while0
->
Op
()
->
SetOutput
(
"Out"
,
while_Outs
);
...
@@ -1684,213 +2915,200 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1684,213 +2915,200 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
// link CacheKV to while
// link CacheKV to while
IR_NODE_LINK_TO
(
cache_kv
,
while0
)
IR_NODE_LINK_TO
(
cache_kv
,
while0
)
// unlink origin KV output to while
// unlink origin KV output to while
IR_NODE_UNLINK
(
transpose2_1_out
,
while0
);
IR_NODE_UNLINK
(
split0_k_out
,
while0
);
IR_NODE_UNLINK
(
transpose2_2_out
,
while0
);
IR_NODE_UNLINK
(
split0_v_out
,
while0
);
IR_NODE_UNLINK
(
while0
,
transpose2_1_out
);
IR_NODE_UNLINK
(
while0
,
split0_k_out
);
IR_NODE_UNLINK
(
while0
,
transpose2_2_out
);
IR_NODE_UNLINK
(
while0
,
split0_v_out
);
// unlink KV weight/bias to while after merged into Q weight/bias
IR_NODE_UNLINK
(
matmul1_w
,
while0
);
IR_NODE_UNLINK
(
matmul2_w
,
while0
);
IR_NODE_UNLINK
(
eltadd1_b
,
while0
);
IR_NODE_UNLINK
(
eltadd2_b
,
while0
);
};
};
int
fusion_count
{
0
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"fused_multi_transformer_encoder
pass in
"
LOG
(
WARNING
)
<<
"fused_multi_transformer_encoder
_fuse_qkv
"
"op compat failed."
;
"
pass in
op compat failed."
;
return
;
return
;
}
}
VLOG
(
4
)
<<
"handle MultiTransformer encoder fuse"
;
VLOG
(
4
)
<<
"handle MultiTransformer encoder(Fuse-QKV) fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
input0
,
input0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0
,
matmul0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_out
,
matmul0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_w
,
matmul0_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1
,
matmul1
,
fused_multi_transformer_pattern
);
input0
,
input0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1_out
,
matmul1_out
,
fused_multi_transformer_pattern
);
layer_norm
,
layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul
1_w
,
matmul1_w
,
fused_multi_transformer
_pattern
);
matmul
0
,
matmul0
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1
,
reshape2_1
,
fused_multi_transformer
_pattern
);
matmul0_out
,
matmul0_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1_out
,
reshape2_1_out
,
fused_multi_transformer
_pattern
);
matmul0_w
,
matmul0_w
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1
,
transpose2_1
,
fused_multi_transformer_pattern
);
reshape2_0
,
reshape2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1_out
,
transpose2_1_out
,
fused_multi_transformer_pattern
);
transpose2_0
,
transpose2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2
,
matmul2
,
fused_multi_transformer_pattern
);
split0
,
split0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_out
,
matmul2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_w
,
matmul2_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2
,
reshape2_2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2_out
,
reshape2_2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2
,
transpose2_2
,
fused_multi_transformer
_pattern
);
split0_q_out
,
split0_q_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2_out
,
transpose2_2_out
,
fused_multi_transformer_pattern
);
split0_k_out
,
split0_k_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
attention_output
,
attention_output
,
fused_multi_transformer_pattern
)
split0_v_out
,
split0_v_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
while0
,
while0
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
ffn_layer_norm
,
ffn_layer_norm
,
fused_multi_transformer_pattern
);
ffn_layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
fused_multi_transformer_pattern
);
fused_multi_transformer_
fuse_qkv_
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
fused_multi_transformer_pattern
);
fused_multi_transformer_
fuse_qkv_
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
fused_multi_transformer_pattern
);
fused_multi_transformer_
fuse_qkv_
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
fused_multi_transformer_pattern
);
fused_multi_transformer_
fuse_qkv_
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
ffn_layer_norm_out
,
ffn_layer_norm_out
,
fused_multi_transformer_pattern
);
fused_multi_transformer_
fuse_qkv_
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_pattern
);
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
ffn_matmul0_out
,
ffn_matmul0_out
,
fused_multi_transformer_pattern
);
ffn_matmul0_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
fused_multi_transformer_fuse_qkv_pattern
);
ffn_matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
eltadd0
,
ffn_eltadd0
,
fused_multi_transformer
_pattern
);
ffn_
matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
_b
,
ffn_eltadd0_b
,
fused_multi_transformer
_pattern
);
ffn_eltadd0
,
ffn_eltadd0
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_pattern
);
ffn_eltadd0_b
,
ffn_eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu
,
ffn_gelu
,
fused_multi_transformer
_pattern
);
ffn_
act
,
ffn_act
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu_out
,
ffn_gelu_out
,
fused_multi_transformer
_pattern
);
ffn_
act_out
,
ffn_act_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_pattern
);
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
ffn_matmul1_out
,
ffn_matmul1_out
,
fused_multi_transformer_pattern
);
ffn_matmul1_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
fused_multi_transformer_fuse_qkv_pattern
);
ffn_matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
eltadd1
,
ffn_eltadd1
,
fused_multi_transformer
_pattern
);
ffn_
matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1
_b
,
ffn_eltadd1_b
,
fused_multi_transformer
_pattern
);
ffn_eltadd1
,
ffn_eltadd1
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_pattern
);
ffn_eltadd1_b
,
ffn_eltadd1_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_pattern
)
ffn_output
,
ffn_output
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_output
,
ffn_output
,
fused_multi_transformer_pattern
)
// nodes need be removed
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
fused_multi_transformer_pattern
);
eltadd0
,
eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_b
,
eltadd0_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1
,
eltadd1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd
1_b
,
eltadd1_b
,
fused_multi_transformer
_pattern
);
eltadd
0_b
,
eltadd0_b
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd
1_out
,
eltadd1_out
,
fused_multi_transformer
_pattern
);
eltadd
0_out
,
eltadd0_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2
,
eltadd2
,
fused_multi_transformer_pattern
);
matmul_qk
,
matmul_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2_b
,
eltadd2_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2_out
,
eltadd2_out
,
fused_multi_transformer
_pattern
);
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk
,
matmul_qk
,
fused_multi_transformer
_pattern
);
scale_qk
,
scale_qk
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer
_pattern
);
scale_qk_out
,
scale_qk_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_pattern
);
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_
fuse_qkv_
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_pattern
);
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_
fuse_qkv_
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_pattern
);
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_
fuse_qkv_
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
fused_multi_transformer_pattern
);
softmax_qk
,
softmax_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
softmax_qk_out
,
fused_multi_transformer_pattern
);
softmax_qk_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
fused_multi_transformer_fuse_qkv_pattern
);
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_pattern
);
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_pattern
);
matmul_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_pattern
);
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
transpose2_qkv_out
,
fused_multi_transformer_pattern
);
fused_multi_transformer_
fuse_qkv_
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear
,
matmul_linear
,
fused_multi_transformer_pattern
)
matmul_linear
,
matmul_linear
,
fused_multi_transformer_
fuse_qkv_
pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
matmul_linear_w
,
matmul_linear_w
,
fused_multi_transformer_pattern
)
matmul_linear_w
,
GET_IR_NODE_FROM_SUBGRAPH
(
fused_multi_transformer_fuse_qkv_pattern
)
matmul_linear_out
,
matmul_linear_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer
_pattern
)
fused_multi_transformer_fuse_qkv
_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_pattern
)
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_pattern
)
eltadd_out
,
eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fused_multi_transformer
_pattern
)
while0
,
while0
,
fused_multi_transformer_fuse_qkv
_pattern
)
fuse_creater
(
input0
,
fuse_creater
(
input0
,
layer_norm
,
layer_norm
,
...
@@ -1900,13 +3118,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1900,13 +3118,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
layer_norm_variance
,
layer_norm_variance
,
matmul0
,
matmul0
,
matmul0_w
,
matmul0_w
,
matmul1_w
,
matmul2_w
,
eltadd0_b
,
eltadd0_b
,
eltadd1_b
,
split0_k_out
,
eltadd2_b
,
split0_v_out
,
transpose2_1_out
,
transpose2_2_out
,
eltadd_qk_b
,
eltadd_qk_b
,
reshape2_0
,
reshape2_0
,
matmul_linear
,
matmul_linear
,
...
@@ -1924,6 +3138,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1924,6 +3138,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
ffn_matmul1_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_eltadd1_b
,
ffn_act
,
ffn_output
);
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
...
@@ -1931,31 +3146,21 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1931,31 +3146,21 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
layer_norm_variance
,
layer_norm_variance
,
layer_norm_out
,
layer_norm_out
,
matmul0
,
matmul0
,
matmul1
,
matmul2
,
matmul0_out
,
matmul0_out
,
matmul1_out
,
matmul2_out
,
eltadd0
,
eltadd0
,
eltadd1
,
eltadd2
,
eltadd0_out
,
eltadd0_out
,
eltadd1_out
,
eltadd2_out
,
reshape2_0
,
reshape2_0
,
reshape2_1
,
reshape2_2
,
reshape2_0_out
,
reshape2_0_out
,
reshape2_1_out
,
reshape2_2_out
,
transpose2_0
,
transpose2_0
,
transpose2_1
,
transpose2_2
,
transpose2_0_out
,
transpose2_0_out
,
transpose2_1_out
,
split0
,
transpose2_2_out
,
split0_q_out
,
split0_k_out
,
split0_v_out
,
matmul_qk
,
matmul_qk
,
matmul_qk_out
,
matmul_qk_out
,
scale_qk
,
scale_qk_out
,
eltadd_qk
,
eltadd_qk
,
eltadd_qk_out
,
eltadd_qk_out
,
softmax_qk
,
softmax_qk
,
...
@@ -1984,8 +3189,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1984,8 +3189,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
ffn_eltadd1
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_eltadd1_out
,
ffn_
gelu
,
ffn_
act
,
ffn_
gelu
_out
,
ffn_
act
_out
,
ffn_eltadd_out
});
ffn_eltadd_out
});
// Remove unneeded nodes.
// Remove unneeded nodes.
...
@@ -1997,23 +3202,25 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1997,23 +3202,25 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
return
fusion_count
;
return
fusion_count
;
}
}
void
FusedMultiTransformerEncoderPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
void
FusedMultiTransformerEncoder
FuseQKV
Pass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
scope
,
scope
,
platform
::
errors
::
Fatal
(
platform
::
errors
::
Fatal
(
"During the multi_transformer pass, The scope should not be null."
));
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."
));
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
if
(
fusion_count
>
0
)
{
graph
->
Set
(
kFusedMultiTransformerEncoderPass
,
new
bool
(
true
));
graph
->
Set
(
kFusedMultiTransformerEncoder
FuseQKV
Pass
,
new
bool
(
true
));
graph
->
Set
(
kFusedMultiTransformerEncoderFusionCount
,
new
int
(
fusion_count
));
graph
->
Set
(
kFusedMultiTransformerEncoderFusionCount
,
new
int
(
fusion_count
));
}
}
AddStatis
(
fusion_count
);
AddStatis
(
fusion_count
);
}
}
FusedMultiTransformerEncoderPass
::
FusedMultiTransformerEncoderPass
()
{
FusedMultiTransformerEncoderFuseQKVPass
::
FusedMultiTransformerEncoderFuseQKVPass
()
{
AddOpCompat
(
OpCompat
(
"layer_norm"
))
AddOpCompat
(
OpCompat
(
"layer_norm"
))
.
AddInput
(
"X"
)
.
AddInput
(
"X"
)
.
IsTensor
()
.
IsTensor
()
...
@@ -2041,6 +3248,23 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
...
@@ -2041,6 +3248,23 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
.
IsNumGT
(
0
)
.
IsNumGT
(
0
)
.
End
();
.
End
();
AddOpCompat
(
OpCompat
(
"scale"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"scale"
)
.
IsType
<
float
>
()
// copy to new op. so unconstrained.
.
End
()
.
AddAttr
(
"bias"
)
.
IsNumEQ
(
0.
f
)
.
End
()
.
AddAttr
(
"bias_after_scale"
)
// bias is 0, so unconstrained.
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
IsTensor
()
...
@@ -2168,56 +3392,54 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
...
@@ -2168,56 +3392,54 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
.
End
();
.
End
();
}
}
int
FusedMultiTransformerEncoderFuseQKV
Pass
::
BuildFusion
(
int
MultiDevicesFusedMultiTransformerEncoder
Pass
::
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
auto
*
pattern
=
gpd
.
mutable_pattern
();
bool
enable_int8
=
graph
->
Get
<
bool
>
(
"enable_int8"
);
if
(
enable_int8
)
{
VLOG
(
3
)
<<
"FusedMultiTransformerEncoderFuseQKVPass with int8"
;
}
else
{
VLOG
(
3
)
<<
"FusedMultiTransformerEncoderFuseQKVPass with fp"
;
}
// Create pattern.
// Create pattern.
patterns
::
FusedMultiTransformerEncoderFuseQKV
Pattern
patterns
::
MultiDevicesFusedMultiTransformerEncoder
Pattern
fused_multi_transformer_fuse_qkv
_pattern
(
pattern
,
name_scope
);
multi_devices_fused_multi_transformer
_pattern
(
pattern
,
name_scope
);
fused_multi_transformer_fuse_qkv
_pattern
();
multi_devices_fused_multi_transformer
_pattern
();
// Create New OpDesc
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
c_identity
,
Node
*
layer_norm
,
Node
*
layer_norm
,
Node
*
layer_norm_scale
,
Node
*
layer_norm_scale
,
Node
*
layer_norm_bias
,
Node
*
layer_norm_bias
,
Node
*
layer_norm_mean
,
Node
*
layer_norm_mean
,
Node
*
layer_norm_variance
,
Node
*
layer_norm_variance
,
Node
*
matmul0
,
Node
*
matmul0_w
,
Node
*
matmul0_w
,
Node
*
matmul1_w
,
Node
*
matmul2_w
,
Node
*
eltadd0_b
,
Node
*
eltadd0_b
,
Node
*
split0_k_out
,
Node
*
eltadd1_b
,
Node
*
split0_v_out
,
Node
*
eltadd2_b
,
Node
*
transpose2_1_out
,
Node
*
transpose2_2_out
,
Node
*
eltadd_qk_b
,
Node
*
eltadd_qk_b
,
Node
*
reshape2_0
,
Node
*
reshape2_0
,
Node
*
matmul_linear
,
Node
*
matmul_linear_w
,
Node
*
matmul_linear_w
,
Node
*
eltadd_linear_b
,
Node
*
eltadd_linear_b
,
Node
*
while0
,
Node
*
ffn_layer_norm
,
Node
*
ffn_layer_norm
,
Node
*
ffn_layer_norm_scale
,
Node
*
ffn_layer_norm_scale
,
Node
*
ffn_layer_norm_bias
,
Node
*
ffn_layer_norm_bias
,
Node
*
ffn_layer_norm_mean
,
Node
*
ffn_layer_norm_mean
,
Node
*
ffn_layer_norm_variance
,
Node
*
ffn_layer_norm_variance
,
Node
*
ffn_matmul0
,
Node
*
ffn_matmul0_w
,
Node
*
ffn_matmul0_w
,
Node
*
ffn_matmul1
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_output
)
{
Node
*
ffn_act
,
auto
*
matmul0_op
=
matmul0
->
Op
();
Node
*
ffn_layer_norm_out
)
{
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
auto
reshape_desc
=
reshape2_0
->
Op
();
auto
*
ffn_matmul_0_op
=
ffn_matmul0
->
Op
();
int
num_head
=
auto
*
ffn_matmul_1_op
=
ffn_matmul1
->
Op
();
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
2
);
int
dim_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
3
);
// Calc index of transformer layer by LayerNorm Scale name
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// This calculation assumes:
...
@@ -2228,47 +3450,47 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2228,47 +3450,47 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
auto
*
qkv_w
_tensor
=
auto
*
wq
_tensor
=
scope
->
FindVar
(
matmul0_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
scope
->
FindVar
(
matmul0_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
qkv_b_tensor
=
auto
*
wk_tensor
=
scope
->
FindVar
(
matmul1_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
wv_tensor
=
scope
->
FindVar
(
matmul2_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
bq_tensor
=
scope
->
FindVar
(
eltadd0_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
scope
->
FindVar
(
eltadd0_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
bk_tensor
=
scope
->
FindVar
(
eltadd1_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
bv_tensor
=
scope
->
FindVar
(
eltadd2_b
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
// NOTE(minghaoBD): to make it compatible with strucutured pruning on
int
dim_embed
=
wq_tensor
->
dims
()[
0
];
// num_head dimension:
// 1. get dim_head from reshape.shape[3], dim_embed from
// layer_norm_bias.shape[0]
// 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head
auto
reshape_desc
=
reshape2_0
->
Op
();
int
dim_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
3
)
/
3
;
// 3 for qkv
auto
*
layer_norm_bias_tensor
=
scope
->
FindVar
(
layer_norm_bias
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
int
dim_embed
=
layer_norm_bias_tensor
->
dims
()[
0
];
int
num_head
=
qkv_w_tensor
->
dims
()[
1
]
/
3
/
dim_head
;
QKVWeightsBiasProcessFuseQKV
(
QKVWeightsBiasProcess
(
wq_tensor
,
qkv_w_tensor
,
qkv_b_tensor
,
num_head
,
dim_head
,
dim_embed
);
wk_tensor
,
wv_tensor
,
bq_tensor
,
bk_tensor
,
bv_tensor
,
num_head
,
dim_head
,
dim_embed
);
if
(
enable_int8
)
{
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto
*
out_linear_w_tensor
=
scope
->
FindVar
(
matmul_linear_w
->
Name
())
auto
*
combined_w_desc
=
matmul0_w
->
Var
();
->
GetMutable
<
phi
::
DenseTensor
>
();
combined_w_desc
->
SetShape
({
3
,
num_head
,
dim_head
,
dim_embed
});
auto
*
ffn0_w_tensor
=
combined_w_desc
->
SetPersistable
(
true
);
scope
->
FindVar
(
ffn_matmul0_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
ffn1_w_tensor
=
scope
->
FindVar
(
ffn_matmul1_w
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
TransposeWeights
(
out_linear_w_tensor
);
auto
*
combined_bias_desc
=
eltadd0_b
->
Var
();
TransposeWeights
(
ffn0_w_tensor
);
combined_bias_desc
->
SetShape
({
3
,
num_head
,
dim_head
});
TransposeWeights
(
ffn1_w_tensor
);
combined_bias_desc
->
SetPersistable
(
true
);
}
scope
->
EraseVars
({
matmul1_w
->
Name
(),
matmul2_w
->
Name
()});
scope
->
EraseVars
({
eltadd1_b
->
Name
(),
eltadd2_b
->
Name
()});
// create fused_multi_transformer
// create fused_multi_transformer
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
fused_multi_transformer_op_desc
.
SetType
(
enable_int8
fused_multi_transformer_op_desc
.
SetType
(
"fused_multi_transformer"
);
?
"fused_multi_transformer_int8"
:
"fused_multi_transformer"
);
// 1. Input setting
// 1. Input setting
fused_multi_transformer_op_desc
.
SetInput
(
"X"
,
{
input0
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"X"
,
{
input0
->
Name
()});
...
@@ -2288,7 +3510,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2288,7 +3510,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
VarDesc
cache_kv_desc
(
"cache_kv"
+
std
::
to_string
(
layer_idx
));
VarDesc
cache_kv_desc
(
"cache_kv"
+
std
::
to_string
(
layer_idx
));
// FIXME: only support max_seq_len <= 1024
// FIXME: only support max_seq_len <= 1024
cache_kv_desc
.
SetDataType
(
cache_kv_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
qkv_b
_tensor
->
dtype
()));
framework
::
TransToProtoVarType
(
wq
_tensor
->
dtype
()));
cache_kv_desc
.
SetPersistable
(
false
);
cache_kv_desc
.
SetPersistable
(
false
);
auto
*
cache_kv
=
graph
->
CreateVarNode
(
&
cache_kv_desc
);
auto
*
cache_kv
=
graph
->
CreateVarNode
(
&
cache_kv_desc
);
...
@@ -2301,9 +3523,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2301,9 +3523,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fill_const_op_desc
.
SetAttr
(
"input_dim_idx"
,
0
);
fill_const_op_desc
.
SetAttr
(
"input_dim_idx"
,
0
);
fill_const_op_desc
.
SetAttr
(
"output_dim_idx"
,
1
);
fill_const_op_desc
.
SetAttr
(
"output_dim_idx"
,
1
);
fill_const_op_desc
.
SetAttr
(
"value"
,
0
);
fill_const_op_desc
.
SetAttr
(
"value"
,
0
);
fill_const_op_desc
.
SetAttr
(
"dtype"
,
fill_const_op_desc
.
SetAttr
(
static_cast
<
int
>
(
framework
::
TransToProtoVarType
(
"dtype"
,
qkv_b
_tensor
->
dtype
())));
static_cast
<
int
>
(
framework
::
TransToProtoVarType
(
wq
_tensor
->
dtype
())));
auto
*
fill_const_op
=
graph
->
CreateOpNode
(
&
fill_const_op_desc
);
auto
*
fill_const_op
=
graph
->
CreateOpNode
(
&
fill_const_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv
->
Name
()});
...
@@ -2329,137 +3551,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2329,137 +3551,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
{
ffn_eltadd1_b
->
Name
()});
{
ffn_eltadd1_b
->
Name
()});
// 2. Output setting
// 2. Output setting
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_output
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_layer_norm_out
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv
->
Name
()});
// Attribute setting
// Attribute setting
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
tru
e
);
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
fals
e
);
fused_multi_transformer_op_desc
.
SetAttr
(
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
"act_method"
,
{
ffn_act
->
Op
()
->
Type
()});
// Quantization attribute/Input
// parallel ring id
if
(
enable_int8
)
{
auto
*
c_identity_op
=
c_identity
->
Op
();
// Set input scale
fused_multi_transformer_op_desc
.
SetAttr
(
"ring_id"
,
std
::
string
qkv_input_name
=
matmul0_op
->
Input
(
"X"
)[
0
];
c_identity_op
->
GetAttr
(
"ring_id"
));
auto
qkv_in_scale
=
PADDLE_GET_CONST
(
float
,
matmul0_op
->
GetAttr
(
"Input_scale_"
+
qkv_input_name
));
std
::
string
out_linear_input_name
=
matmul_linear_op
->
Input
(
"X"
)[
0
];
auto
out_linear_in_scale
=
PADDLE_GET_CONST
(
float
,
matmul_linear_op
->
GetAttr
(
"Input_scale_"
+
out_linear_input_name
));
std
::
string
ffn0_input_name
=
ffn_matmul_0_op
->
Input
(
"X"
)[
0
];
auto
ffn0_in_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_0_op
->
GetAttr
(
"Input_scale_"
+
ffn0_input_name
));
std
::
string
ffn1_input_name
=
ffn_matmul_1_op
->
Input
(
"X"
)[
0
];
auto
ffn1_in_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_1_op
->
GetAttr
(
"Input_scale_"
+
ffn1_input_name
));
// Calc outscale and Set them
// TODO(wufeisheng): Currently just match layer-wise weight scale, where
// channel-wise weight scale should also be surpported.
auto
qkv_weight_scale
=
PADDLE_GET_CONST
(
float
,
matmul0_op
->
GetAttr
(
"weight_scale"
));
auto
out_weight_scale
=
PADDLE_GET_CONST
(
float
,
matmul_linear_op
->
GetAttr
(
"weight_scale"
));
auto
ffn0_weight_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_0_op
->
GetAttr
(
"weight_scale"
));
auto
ffn1_weight_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_1_op
->
GetAttr
(
"weight_scale"
));
auto
qkv_out_scales
=
std
::
vector
<
float
>
(
3
*
dim_embed
,
(
qkv_weight_scale
/
127.0
f
)
*
(
qkv_in_scale
/
127.0
f
));
auto
out_out_scales
=
std
::
vector
<
float
>
(
dim_embed
,
(
out_weight_scale
/
127.0
f
)
*
(
out_linear_in_scale
/
127.0
f
));
auto
ffn0_out_scales
=
std
::
vector
<
float
>
(
4
*
dim_embed
,
(
ffn0_weight_scale
/
127.0
f
)
*
(
ffn0_in_scale
/
127.0
f
));
auto
ffn1_out_scales
=
std
::
vector
<
float
>
(
dim_embed
,
(
ffn1_weight_scale
/
127.0
f
)
*
(
ffn1_in_scale
/
127.0
f
));
// Inverse input scale
qkv_in_scale
=
1.0
f
/
qkv_in_scale
;
out_linear_in_scale
=
1.0
f
/
out_linear_in_scale
;
ffn0_in_scale
=
1.0
f
/
ffn0_in_scale
;
ffn1_in_scale
=
1.0
f
/
ffn1_in_scale
;
fused_multi_transformer_op_desc
.
SetAttr
(
"qkv_in_scale"
,
std
::
vector
<
float
>
{
qkv_in_scale
});
fused_multi_transformer_op_desc
.
SetAttr
(
"out_linear_in_scale"
,
std
::
vector
<
float
>
{
out_linear_in_scale
});
fused_multi_transformer_op_desc
.
SetAttr
(
"ffn1_in_scale"
,
std
::
vector
<
float
>
{
ffn0_in_scale
});
fused_multi_transformer_op_desc
.
SetAttr
(
"ffn2_in_scale"
,
std
::
vector
<
float
>
{
ffn1_in_scale
});
auto
qkv_out_scale_var
=
scope
->
Var
(
matmul0_w
->
Name
()
+
"_out_scale"
);
auto
out_out_scale_var
=
scope
->
Var
(
matmul_linear_w
->
Name
()
+
"_out_scale"
);
auto
ffn0_out_scale_var
=
scope
->
Var
(
ffn_matmul0_w
->
Name
()
+
"_out_scale"
);
auto
ffn1_out_scale_var
=
scope
->
Var
(
ffn_matmul1_w
->
Name
()
+
"_out_scale"
);
auto
qkv_out_scale_data
=
qkv_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
->
mutable_data
<
float
>
({
3
*
dim_embed
},
platform
::
CPUPlace
());
memcpy
(
qkv_out_scale_data
,
qkv_out_scales
.
data
(),
qkv_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
"QKVOutScale"
,
{
matmul0_w
->
Name
()
+
"_out_scale"
});
auto
out_out_scale_data
=
out_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
->
mutable_data
<
float
>
({
dim_embed
},
platform
::
CPUPlace
());
memcpy
(
out_out_scale_data
,
out_out_scales
.
data
(),
out_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearOutScale"
,
{
matmul_linear_w
->
Name
()
+
"_out_scale"
});
auto
ffn0_out_scale_data
=
ffn0_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
->
mutable_data
<
float
>
({
4
*
dim_embed
},
platform
::
CPUPlace
());
memcpy
(
ffn0_out_scale_data
,
ffn0_out_scales
.
data
(),
ffn0_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1OutScale"
,
{
ffn_matmul0_w
->
Name
()
+
"_out_scale"
});
auto
ffn1_out_scale_data
=
ffn1_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
->
mutable_data
<
float
>
({
dim_embed
},
platform
::
CPUPlace
());
memcpy
(
ffn1_out_scale_data
,
ffn1_out_scales
.
data
(),
ffn1_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2OutScale"
,
{
ffn_matmul1_w
->
Name
()
+
"_out_scale"
});
}
auto
*
fused_multi_transformer
=
auto
*
fused_multi_transformer
=
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
if
(
enable_int8
)
{
auto
qkv_out_scale_node
=
CreatePersistableVarNode
(
graph
,
matmul0_w
->
Name
()
+
"_out_scale"
);
auto
out_out_scale_node
=
CreatePersistableVarNode
(
graph
,
matmul_linear_w
->
Name
()
+
"_out_scale"
);
auto
ffn0_out_scale_node
=
CreatePersistableVarNode
(
graph
,
ffn_matmul0_w
->
Name
()
+
"_out_scale"
);
auto
ffn1_out_scale_node
=
CreatePersistableVarNode
(
graph
,
ffn_matmul1_w
->
Name
()
+
"_out_scale"
);
IR_NODE_LINK_TO
(
qkv_out_scale_node
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
out_out_scale_node
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn0_out_scale_node
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn1_out_scale_node
,
fused_multi_transformer
);
}
IR_NODE_LINK_TO
(
input0
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
input0
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_bias
,
fused_multi_transformer
);
...
@@ -2477,291 +3589,374 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2477,291 +3589,374 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO
(
ffn_layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_matmul1_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_matmul1_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_eltadd1_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
ffn_eltadd1_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
fused_multi_transformer
,
ffn_output
);
IR_NODE_LINK_TO
(
fused_multi_transformer
,
ffn_layer_norm_out
);
// rewrite while OP input
// 1. delete k, v
// 2. delete matmul1/2_w eltadd1/2_w
// 3. add cache_kv
auto
while_Xs
=
while0
->
Op
()
->
Input
(
"X"
);
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
split0_k_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
split0_v_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetInput
(
"X"
,
while_Xs
);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto
while_Outs
=
while0
->
Op
()
->
Output
(
"Out"
);
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
split0_k_out
->
Name
()),
std
::
end
(
while_Outs
));
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
split0_v_out
->
Name
()),
std
::
end
(
while_Outs
));
while_Outs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetOutput
(
"Out"
,
while_Outs
);
// link CacheKV to while
IR_NODE_LINK_TO
(
cache_kv
,
while0
)
// unlink origin KV output to while
IR_NODE_UNLINK
(
split0_k_out
,
while0
);
IR_NODE_UNLINK
(
split0_v_out
,
while0
);
IR_NODE_UNLINK
(
while0
,
split0_k_out
);
IR_NODE_UNLINK
(
while0
,
split0_v_out
);
};
};
int
fusion_count
{
0
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"fused_multi_transformer_encoder
_fuse_qkv
"
LOG
(
WARNING
)
<<
"fused_multi_transformer_encoder
pass in
"
"
pass in
op compat failed."
;
"op compat failed."
;
return
;
return
;
}
}
VLOG
(
4
)
<<
"handle MultiTransformer encoder(Fuse-QKV) fuse"
;
VLOG
(
4
)
<<
"handle MultiTransformer encoder fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
input0
,
input0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
input0
,
input0
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity0
,
c_identity0
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity0_out
,
c_identity0_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity1
,
c_identity1
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity1_out
,
c_identity1_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity2
,
c_identity2
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity2_out
,
c_identity2_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_identity
,
ffn_c_identity
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_identity_out
,
ffn_c_identity_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0
,
matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
matmul0
,
matmul0
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_out
,
matmul0_out
,
matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
matmul0_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_w
,
matmul0_w
,
fused_multi_transformer_fuse_qkv
_pattern
);
matmul0_w
,
matmul0_w
,
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
fused_multi_transformer_fuse_qkv
_pattern
);
reshape2_0
,
reshape2_0
,
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
transpose2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
transpose2_0
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
split0
,
split0
,
fused_multi_transformer_fuse_qkv_pattern
);
matmul1
,
matmul1
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1_out
,
matmul1_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
split0_q_out
,
split0_q_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
matmul1_w
,
matmul1_w
,
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
split0_k_out
,
split0_k_out
,
fused_multi_transformer_fuse_qkv_pattern
);
reshape2_1
,
reshape2_1
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1_out
,
reshape2_1_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1
,
transpose2_1
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1_out
,
transpose2_1_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
split0_v_out
,
split0_v_out
,
fused_multi_transformer_fuse_qkv_pattern
);
scale_q
,
scale_q
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_q_out
,
scale_q_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2
,
matmul2
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_out
,
matmul2_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_w
,
matmul2_w
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2
,
reshape2_2
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2_out
,
reshape2_2_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2
,
transpose2_2
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2_out
,
transpose2_2_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
attention_output
,
attention_output
,
multi_devices_fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
ffn_layer_norm
,
ffn_layer_norm
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
ffn_layer_norm_out
,
ffn_layer_norm_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0
,
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_matmul0
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
ffn_matmul0_out
,
ffn_matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_w
,
ffn_matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_matmul0_w
,
GET_IR_NODE_FROM_SUBGRAPH
(
multi_devices_fused_multi_transformer_pattern
);
ffn_eltadd0
,
ffn_eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
,
ffn_eltadd0_b
,
ffn_eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_b
,
ffn_eltadd0_b
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu
,
ffn_gelu
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_act
,
ffn_act
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_act_out
,
ffn_gelu_out
,
ffn_gelu_out
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_act_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_matmul1
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
ffn_matmul1_out
,
ffn_matmul1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_w
,
ffn_matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_matmul1_w
,
GET_IR_NODE_FROM_SUBGRAPH
(
multi_devices_fused_multi_transformer_pattern
);
ffn_eltadd1
,
ffn_eltadd1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_allreduce_sum
,
ffn_eltadd1_b
,
ffn_eltadd1_b
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_c_allreduce_sum
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_allreduce_sum_out
,
ffn_c_allreduce_sum_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1
,
ffn_eltadd1
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_b
,
ffn_eltadd1_b
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_fuse_qkv
_pattern
)
multi_devices_fused_multi_transformer
_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_output
,
ffn_output
,
fused_multi_transformer_fuse_qkv
_pattern
)
ffn_output
,
ffn_output
,
multi_devices_fused_multi_transformer
_pattern
)
// nodes need be removed
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
eltadd0
,
eltadd0
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_b
,
eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
eltadd0_b
,
eltadd0_b
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk
,
matmul_qk
,
fused_multi_transformer_fuse_qkv
_pattern
);
eltadd1
,
eltadd1
,
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
eltadd1_b
,
eltadd1_b
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1_out
,
eltadd1_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk
,
scale_qk
,
fused_multi_transformer_fuse_qkv
_pattern
);
eltadd2
,
eltadd2
,
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk_out
,
scale_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
eltadd2_b
,
eltadd2_b
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2_out
,
eltadd2_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
matmul_qk
,
matmul_qk
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_fuse_qkv_pattern
);
matmul_qk_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
eltadd_qk
,
eltadd_qk
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_b
,
eltadd_qk_b
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
fused_multi_transformer_fuse_qkv
_pattern
);
softmax_qk
,
softmax_qk
,
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
softmax_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_fuse_qkv
_pattern
);
matmul_qkv
,
matmul_qkv
,
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
matmul_qkv_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
reshape2_qkv
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
transpose2_qkv_out
,
fused_multi_transformer_fuse_qkv
_pattern
);
multi_devices_fused_multi_transformer
_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear
,
matmul_linear
,
matmul_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
matmul_linear
,
multi_devices_fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
matmul_linear_w
,
matmul_linear_w
,
fused_multi_transformer_fuse_qkv
_pattern
)
multi_devices_fused_multi_transformer
_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
matmul_linear_out
,
matmul_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
multi_devices_fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
c_allreduce_sum
,
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
c_allreduce_sum
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_allreduce_sum_out
,
c_allreduce_sum_out
,
multi_devices_fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear
,
eltadd_linear
,
multi_devices_fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_fuse_qkv
_pattern
)
multi_devices_fused_multi_transformer
_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
multi_devices_fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
while0
,
while0
,
fused_multi_transformer_fuse_qkv
_pattern
)
eltadd_out
,
eltadd_out
,
multi_devices_fused_multi_transformer
_pattern
)
fuse_creater
(
input0
,
fuse_creater
(
input0
,
c_identity0
,
layer_norm
,
layer_norm
,
layer_norm_scale
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_mean
,
layer_norm_variance
,
layer_norm_variance
,
matmul0
,
matmul0_w
,
matmul0_w
,
matmul1_w
,
matmul2_w
,
eltadd0_b
,
eltadd0_b
,
split0_k_out
,
eltadd1_b
,
split0_v_out
,
eltadd2_b
,
transpose2_1_out
,
transpose2_2_out
,
eltadd_qk_b
,
eltadd_qk_b
,
reshape2_0
,
reshape2_0
,
matmul_linear
,
matmul_linear_w
,
matmul_linear_w
,
eltadd_linear_b
,
eltadd_linear_b
,
while0
,
ffn_layer_norm
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
ffn_matmul0
,
ffn_matmul0_w
,
ffn_matmul0_w
,
ffn_matmul1
,
ffn_matmul1_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_eltadd1_b
,
ffn_output
);
ffn_act
,
ffn_layer_norm_out
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
c_identity0
,
c_identity0_out
,
c_identity1
,
c_identity1_out
,
c_identity2
,
c_identity2_out
,
layer_norm
,
layer_norm_mean
,
layer_norm_mean
,
layer_norm_variance
,
layer_norm_variance
,
layer_norm_out
,
layer_norm_out
,
matmul0
,
matmul0
,
matmul1
,
matmul2
,
matmul0_out
,
matmul0_out
,
matmul1_out
,
matmul2_out
,
eltadd0
,
eltadd0
,
eltadd1
,
eltadd2
,
eltadd0_out
,
eltadd0_out
,
eltadd1_out
,
eltadd2_out
,
reshape2_0
,
reshape2_0
,
reshape2_1
,
reshape2_2
,
reshape2_0_out
,
reshape2_0_out
,
reshape2_1_out
,
reshape2_2_out
,
transpose2_0
,
transpose2_0
,
transpose2_1
,
transpose2_2
,
transpose2_0_out
,
transpose2_0_out
,
split0
,
transpose2_1_out
,
split0_q
_out
,
transpose2_2
_out
,
s
plit0_k_out
,
s
cale_q
,
s
plit0_v
_out
,
s
cale_q
_out
,
matmul_qk
,
matmul_qk
,
matmul_qk_out
,
matmul_qk_out
,
scale_qk
,
scale_qk_out
,
eltadd_qk
,
eltadd_qk
,
eltadd_qk_out
,
eltadd_qk_out
,
softmax_qk
,
softmax_qk
,
...
@@ -2775,23 +3970,29 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2775,23 +3970,29 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
transpose2_qkv_out
,
transpose2_qkv_out
,
matmul_linear
,
matmul_linear
,
matmul_linear_out
,
matmul_linear_out
,
c_allreduce_sum
,
c_allreduce_sum_out
,
eltadd_linear
,
eltadd_linear
,
eltadd_linear_out
,
eltadd_linear_out
,
eltadd_out
,
eltadd_out
,
ffn_layer_norm
,
ffn_layer_norm
,
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
ffn_layer_norm_out
,
ffn_c_identity
,
ffn_c_identity_out
,
ffn_matmul0
,
ffn_matmul0
,
ffn_matmul1
,
ffn_matmul1
,
ffn_matmul0_out
,
ffn_matmul0_out
,
ffn_matmul1_out
,
ffn_matmul1_out
,
ffn_c_allreduce_sum
,
ffn_c_allreduce_sum_out
,
ffn_eltadd0
,
ffn_eltadd0
,
ffn_eltadd1
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_eltadd1_out
,
ffn_gelu
,
ffn_act
,
ffn_gelu_out
,
ffn_act_out
,
ffn_output
,
ffn_eltadd_out
});
ffn_eltadd_out
});
// Remove unneeded nodes.
// Remove unneeded nodes.
...
@@ -2803,25 +4004,25 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2803,25 +4004,25 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
return
fusion_count
;
return
fusion_count
;
}
}
void
FusedMultiTransformerEncoderFuseQKVPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
void
MultiDevicesFusedMultiTransformerEncoderPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
scope
,
scope
,
platform
::
errors
::
Fatal
(
platform
::
errors
::
Fatal
(
"During the fused_multi_transformer_encoder pass, "
"During the multi_transformer pass, The scope should not be null."
));
"The scope should not be null."
));
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
if
(
fusion_count
>
0
)
{
graph
->
Set
(
kFusedMultiTransformerEncoder
FuseQKV
Pass
,
new
bool
(
true
));
graph
->
Set
(
kFusedMultiTransformerEncoderPass
,
new
bool
(
true
));
graph
->
Set
(
kFusedMultiTransformerEncoderFusionCount
,
new
int
(
fusion_count
));
graph
->
Set
(
kFusedMultiTransformerEncoderFusionCount
,
new
int
(
fusion_count
));
}
}
AddStatis
(
fusion_count
);
AddStatis
(
fusion_count
);
}
}
FusedMultiTransformerEncoderFuseQKV
Pass
::
MultiDevicesFusedMultiTransformerEncoder
Pass
::
FusedMultiTransformerEncoderFuseQKV
Pass
()
{
MultiDevicesFusedMultiTransformerEncoder
Pass
()
{
AddOpCompat
(
OpCompat
(
"layer_norm"
))
AddOpCompat
(
OpCompat
(
"layer_norm"
))
.
AddInput
(
"X"
)
.
AddInput
(
"X"
)
.
IsTensor
()
.
IsTensor
()
...
@@ -2849,23 +4050,6 @@ FusedMultiTransformerEncoderFuseQKVPass::
...
@@ -2849,23 +4050,6 @@ FusedMultiTransformerEncoderFuseQKVPass::
.
IsNumGT
(
0
)
.
IsNumGT
(
0
)
.
End
();
.
End
();
AddOpCompat
(
OpCompat
(
"scale"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"scale"
)
.
IsType
<
float
>
()
// copy to new op. so unconstrained.
.
End
()
.
AddAttr
(
"bias"
)
.
IsNumEQ
(
0.
f
)
.
End
()
.
AddAttr
(
"bias_after_scale"
)
// bias is 0, so unconstrained.
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
IsTensor
()
...
@@ -2935,24 +4119,20 @@ FusedMultiTransformerEncoderFuseQKVPass::
...
@@ -2935,24 +4119,20 @@ FusedMultiTransformerEncoderFuseQKVPass::
.
IsType
<
std
::
vector
<
int
>>
()
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
.
End
();
AddOpCompat
(
OpCompat
(
"
matmul
"
))
AddOpCompat
(
OpCompat
(
"
scale
"
))
.
AddInput
(
"X"
)
.
AddInput
(
"X"
)
.
IsTensor
()
.
IsTensor
()
.
End
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
IsTensor
()
.
End
()
.
End
()
.
AddAttr
(
"alpha"
)
.
AddAttr
(
"scale"
)
.
IsNumGE
(
0.0
f
)
.
IsType
<
float
>
()
// copy to new op. so unconstrained.
.
IsNumLE
(
1.0
f
)
.
End
()
.
End
()
.
AddAttr
(
"
transpose_X
"
)
.
AddAttr
(
"
bias
"
)
.
Is
BoolEQ
(
false
)
.
Is
NumEQ
(
0.
f
)
.
End
()
.
End
()
.
AddAttr
(
"
transpose_Y"
)
.
AddAttr
(
"
bias_after_scale"
)
// bias is 0, so unconstrained.
.
IsType
<
bool
>
()
.
IsType
<
bool
>
()
.
End
();
.
End
();
...
@@ -2978,18 +4158,12 @@ FusedMultiTransformerEncoderFuseQKVPass::
...
@@ -2978,18 +4158,12 @@ FusedMultiTransformerEncoderFuseQKVPass::
.
IsType
<
bool
>
()
.
IsType
<
bool
>
()
.
End
();
.
End
();
AddOpCompat
(
OpCompat
(
"while"
))
AddOpCompat
(
OpCompat
(
"relu"
))
.
AddInput
(
"X"
)
// A set of variables, unconstrained
.
AddInput
(
"X"
)
.
End
()
.
AddInput
(
"Condition"
)
// An scalar
.
IsTensor
()
.
IsTensor
()
.
End
()
.
End
()
.
AddOutput
(
"Out"
)
// A set of variables, unconstrained
.
AddOutput
(
"Out"
)
.
End
()
.
IsTensor
()
.
AddOutput
(
"StepScopes"
)
// A vector of local scope, unconstrained
.
End
()
.
AddAttr
(
"sub_block"
)
.
IsType
<
framework
::
BlockDesc
*>
()
.
End
();
.
End
();
}
}
...
@@ -3040,6 +4214,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -3040,6 +4214,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node
*
ffn_matmul1_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_act
,
Node
*
ffn_output
)
{
Node
*
ffn_output
)
{
auto
*
matmul0_op
=
matmul0
->
Op
();
auto
*
matmul0_op
=
matmul0
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
auto
*
matmul_linear_op
=
matmul_linear
->
Op
();
...
@@ -3163,6 +4338,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -3163,6 +4338,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"act_method"
,
ffn_act
->
Op
()
->
Type
());
// output dropout attribute
// output dropout attribute
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
...
@@ -3175,6 +4352,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -3175,6 +4352,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
// Quantization attribute/Input
// Quantization attribute/Input
if
(
enable_int8
)
{
if
(
enable_int8
)
{
auto
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()));
// Set input scale
// Set input scale
std
::
string
matmul_input_scale_suffix
=
c_identity_op
->
Input
(
"X"
)[
0
];
std
::
string
matmul_input_scale_suffix
=
c_identity_op
->
Input
(
"X"
)[
0
];
auto
qkv_in_scale
=
PADDLE_GET_CONST
(
auto
qkv_in_scale
=
PADDLE_GET_CONST
(
...
@@ -3240,36 +4419,44 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -3240,36 +4419,44 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
auto
ffn1_out_scale_var
=
auto
ffn1_out_scale_var
=
scope
->
Var
(
ffn_matmul1_w
->
Name
()
+
"_out_scale"
);
scope
->
Var
(
ffn_matmul1_w
->
Name
()
+
"_out_scale"
);
auto
qkv_out_scale_data
=
auto
*
qkv_out_scale_tensor
=
qkv_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
qkv_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
->
mutable_data
<
float
>
({
3
*
dim_embed
},
platform
::
CPUPlace
());
qkv_out_scale_tensor
->
Resize
({
3
*
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
qkv_out_scale_tensor
);
auto
qkv_out_scale_data
=
qkv_out_scale_tensor
->
data
<
float
>
();
memcpy
(
qkv_out_scale_data
,
memcpy
(
qkv_out_scale_data
,
qkv_out_scales
.
data
(),
qkv_out_scales
.
data
(),
qkv_out_scales
.
size
()
*
sizeof
(
float
));
qkv_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
fused_multi_transformer_op_desc
.
SetInput
(
"QKVOutScale"
,
{
matmul0_w
->
Name
()
+
"_out_scale"
});
"QKVOutScale"
,
{
matmul0_w
->
Name
()
+
"_out_scale"
});
auto
out_out_scale_data
=
auto
*
out_out_scale_tensor
=
out_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
out_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
->
mutable_data
<
float
>
({
dim_embed
},
platform
::
CPUPlace
());
out_out_scale_tensor
->
Resize
({
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
out_out_scale_tensor
);
auto
out_out_scale_data
=
out_out_scale_tensor
->
data
<
float
>
();
memcpy
(
out_out_scale_data
,
memcpy
(
out_out_scale_data
,
out_out_scales
.
data
(),
out_out_scales
.
data
(),
out_out_scales
.
size
()
*
sizeof
(
float
));
out_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearOutScale"
,
{
matmul_linear_w
->
Name
()
+
"_out_scale"
});
"OutLinearOutScale"
,
{
matmul_linear_w
->
Name
()
+
"_out_scale"
});
auto
ffn0_out_scale_data
=
auto
*
ffn0_out_scale_tensor
=
ffn0_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
ffn0_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
->
mutable_data
<
float
>
({
4
*
dim_embed
},
platform
::
CPUPlace
());
ffn0_out_scale_tensor
->
Resize
({
4
*
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
ffn0_out_scale_tensor
);
auto
ffn0_out_scale_data
=
ffn0_out_scale_tensor
->
data
<
float
>
();
memcpy
(
ffn0_out_scale_data
,
memcpy
(
ffn0_out_scale_data
,
ffn0_out_scales
.
data
(),
ffn0_out_scales
.
data
(),
ffn0_out_scales
.
size
()
*
sizeof
(
float
));
ffn0_out_scales
.
size
()
*
sizeof
(
float
));
fused_multi_transformer_op_desc
.
SetInput
(
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1OutScale"
,
{
ffn_matmul0_w
->
Name
()
+
"_out_scale"
});
"FFN1OutScale"
,
{
ffn_matmul0_w
->
Name
()
+
"_out_scale"
});
auto
ffn1_out_scale_data
=
auto
*
ffn1_out_scale_tensor
=
ffn1_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
()
ffn1_out_scale_var
->
GetMutable
<
phi
::
DenseTensor
>
();
->
mutable_data
<
float
>
({
dim_embed
},
platform
::
CPUPlace
());
ffn1_out_scale_tensor
->
Resize
({
dim_embed
});
dev_ctx
->
Alloc
<
float
>
(
ffn1_out_scale_tensor
);
auto
ffn1_out_scale_data
=
ffn1_out_scale_tensor
->
data
<
float
>
();
memcpy
(
ffn1_out_scale_data
,
memcpy
(
ffn1_out_scale_data
,
ffn1_out_scales
.
data
(),
ffn1_out_scales
.
data
(),
ffn1_out_scales
.
size
()
*
sizeof
(
float
));
ffn1_out_scales
.
size
()
*
sizeof
(
float
));
...
@@ -3464,9 +4651,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -3464,9 +4651,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern
);
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu
,
ffn_gelu
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_
act
,
ffn_act
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_
gelu_out
,
ffn_gelu
_out
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_
act_out
,
ffn_act
_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
...
@@ -3603,6 +4790,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -3603,6 +4790,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_matmul1_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_eltadd1_b
,
ffn_act
,
ffn_output
);
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
...
@@ -3661,8 +4849,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -3661,8 +4849,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_eltadd1
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_eltadd1_out
,
ffn_
gelu
,
ffn_
act
,
ffn_
gelu
_out
,
ffn_
act
_out
,
ffn_eltadd_out
});
ffn_eltadd_out
});
// Remove unneeded nodes.
// Remove unneeded nodes.
...
@@ -3874,6 +5062,9 @@ REGISTER_PASS(fused_multi_transformer_encoder_pass,
...
@@ -3874,6 +5062,9 @@ REGISTER_PASS(fused_multi_transformer_encoder_pass,
paddle
::
framework
::
ir
::
FusedMultiTransformerEncoderPass
);
paddle
::
framework
::
ir
::
FusedMultiTransformerEncoderPass
);
REGISTER_PASS
(
fused_multi_transformer_encoder_fuse_qkv_pass
,
REGISTER_PASS
(
fused_multi_transformer_encoder_fuse_qkv_pass
,
paddle
::
framework
::
ir
::
FusedMultiTransformerEncoderFuseQKVPass
);
paddle
::
framework
::
ir
::
FusedMultiTransformerEncoderFuseQKVPass
);
REGISTER_PASS
(
multi_devices_fused_multi_transformer_encoder_pass
,
paddle
::
framework
::
ir
::
MultiDevicesFusedMultiTransformerEncoderPass
);
REGISTER_PASS
(
REGISTER_PASS
(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass
,
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass
,
paddle
::
framework
::
ir
::
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
);
paddle
::
framework
::
ir
::
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
);
...
@@ -3898,6 +5089,16 @@ REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_fuse_qkv_pass)
...
@@ -3898,6 +5089,16 @@ REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_fuse_qkv_pass)
.
LE
(
"matmul"
,
1
)
.
LE
(
"matmul"
,
1
)
.
EQ
(
"matmul_v2"
,
0
)
.
EQ
(
"matmul_v2"
,
0
)
.
EQ
(
"softmax"
,
0
));
.
EQ
(
"softmax"
,
0
));
REGISTER_PASS_CAPABILITY
(
multi_devices_fused_multi_transformer_encoder_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"elementwise_add"
,
1
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"transpose2"
,
0
)
.
EQ
(
"scale"
,
0
)
.
LE
(
"matmul"
,
1
)
.
EQ
(
"matmul_v2"
,
0
)
.
EQ
(
"softmax"
,
0
));
REGISTER_PASS_CAPABILITY
(
REGISTER_PASS_CAPABILITY
(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass
)
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass
)
.
AddCombination
(
.
AddCombination
(
...
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h
浏览文件 @
29eec2dd
...
@@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
...
@@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
// Q, K, V path
// Q, K, V path
PATTERN_DECL_NODE
(
input0
);
PATTERN_DECL_NODE
(
input0
);
PATTERN_DECL_NODE
(
layer_norm
);
PATTERN_DECL_NODE
(
layer_norm_scale
);
PATTERN_DECL_NODE
(
layer_norm_bias
);
PATTERN_DECL_NODE
(
layer_norm_mean
);
PATTERN_DECL_NODE
(
layer_norm_variance
);
PATTERN_DECL_NODE
(
layer_norm_out
);
PATTERN_DECL_NODE
(
matmul0
);
PATTERN_DECL_NODE
(
matmul0
);
PATTERN_DECL_NODE
(
matmul1
);
PATTERN_DECL_NODE
(
matmul1
);
PATTERN_DECL_NODE
(
matmul2
);
PATTERN_DECL_NODE
(
matmul2
);
...
@@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
...
@@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE
(
transpose2_0_out
);
PATTERN_DECL_NODE
(
transpose2_0_out
);
PATTERN_DECL_NODE
(
transpose2_1_out
);
PATTERN_DECL_NODE
(
transpose2_1_out
);
PATTERN_DECL_NODE
(
transpose2_2_out
);
PATTERN_DECL_NODE
(
transpose2_2_out
);
PATTERN_DECL_NODE
(
scale_q
);
PATTERN_DECL_NODE
(
scale_q_out
);
// Q, K matmul
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk
);
...
@@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
...
@@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE
(
eltadd_linear
);
PATTERN_DECL_NODE
(
eltadd_linear
);
PATTERN_DECL_NODE
(
eltadd_linear_b
);
PATTERN_DECL_NODE
(
eltadd_linear_b
);
PATTERN_DECL_NODE
(
eltadd_linear_out
);
PATTERN_DECL_NODE
(
eltadd_linear_out
);
PATTERN_DECL_NODE
(
dropout_linear
);
PATTERN_DECL_NODE
(
dropout_linear_out
);
// output elementwise_add
// output elementwise_add
PATTERN_DECL_NODE
(
eltadd_out
)
PATTERN_DECL_NODE
(
eltadd_out
)
PATTERN_DECL_NODE
(
attention_output
);
PATTERN_DECL_NODE
(
attention_output
);
// while loop
// post layer_norm
PATTERN_DECL_NODE
(
while0
);
PATTERN_DECL_NODE
(
layer_norm
);
PATTERN_DECL_NODE
(
layer_norm_scale
);
PATTERN_DECL_NODE
(
layer_norm_bias
);
PATTERN_DECL_NODE
(
layer_norm_mean
);
PATTERN_DECL_NODE
(
layer_norm_variance
);
PATTERN_DECL_NODE
(
layer_norm_out
);
// Feed Forward nodes
// Feed Forward nodes
PATTERN_DECL_NODE
(
ffn_layer_norm
);
PATTERN_DECL_NODE
(
ffn_layer_norm_scale
);
PATTERN_DECL_NODE
(
ffn_layer_norm_bias
);
PATTERN_DECL_NODE
(
ffn_layer_norm_mean
);
PATTERN_DECL_NODE
(
ffn_layer_norm_variance
);
PATTERN_DECL_NODE
(
ffn_layer_norm_out
);
PATTERN_DECL_NODE
(
ffn_matmul0
);
PATTERN_DECL_NODE
(
ffn_matmul0
);
PATTERN_DECL_NODE
(
ffn_matmul0_w
);
PATTERN_DECL_NODE
(
ffn_matmul0_w
);
PATTERN_DECL_NODE
(
ffn_matmul0_out
);
PATTERN_DECL_NODE
(
ffn_matmul0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_
gelu
);
PATTERN_DECL_NODE
(
ffn_
act
);
PATTERN_DECL_NODE
(
ffn_
gelu
_out
);
PATTERN_DECL_NODE
(
ffn_
act
_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
...
@@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
...
@@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
// output elementwise_add
// output elementwise_add
PATTERN_DECL_NODE
(
ffn_eltadd_out
)
PATTERN_DECL_NODE
(
ffn_eltadd_out
)
PATTERN_DECL_NODE
(
ffn_output
);
PATTERN_DECL_NODE
(
ffn_output
);
PATTERN_DECL_NODE
(
ffn_layer_norm
);
PATTERN_DECL_NODE
(
ffn_layer_norm_scale
);
PATTERN_DECL_NODE
(
ffn_layer_norm_bias
);
PATTERN_DECL_NODE
(
ffn_layer_norm_mean
);
PATTERN_DECL_NODE
(
ffn_layer_norm_variance
);
PATTERN_DECL_NODE
(
ffn_layer_norm_out
);
};
};
struct
FusedMultiTransformerEncoderFuseQKVPattern
:
public
PatternBase
{
struct
FusedMultiTransformerEncoderFuseQKVPattern
:
public
PatternBase
{
...
@@ -212,8 +216,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
...
@@ -212,8 +216,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_
gelu
);
PATTERN_DECL_NODE
(
ffn_
act
);
PATTERN_DECL_NODE
(
ffn_
gelu
_out
);
PATTERN_DECL_NODE
(
ffn_
act
_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
...
@@ -226,6 +230,129 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
...
@@ -226,6 +230,129 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE
(
ffn_output
);
PATTERN_DECL_NODE
(
ffn_output
);
};
};
struct
MultiDevicesFusedMultiTransformerEncoderPattern
:
public
PatternBase
{
MultiDevicesFusedMultiTransformerEncoderPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"multi_devices_fused_multi_transformer_encoder"
)
{}
PDNode
*
operator
()();
// Q, K, V path
PATTERN_DECL_NODE
(
input0
);
PATTERN_DECL_NODE
(
c_identity0
);
PATTERN_DECL_NODE
(
c_identity0_out
);
PATTERN_DECL_NODE
(
c_identity1
);
PATTERN_DECL_NODE
(
c_identity1_out
);
PATTERN_DECL_NODE
(
c_identity2
);
PATTERN_DECL_NODE
(
c_identity2_out
);
PATTERN_DECL_NODE
(
matmul0
);
PATTERN_DECL_NODE
(
matmul1
);
PATTERN_DECL_NODE
(
matmul2
);
PATTERN_DECL_NODE
(
matmul0_w
);
PATTERN_DECL_NODE
(
matmul1_w
);
PATTERN_DECL_NODE
(
matmul2_w
);
PATTERN_DECL_NODE
(
matmul0_out
);
PATTERN_DECL_NODE
(
matmul1_out
);
PATTERN_DECL_NODE
(
matmul2_out
);
PATTERN_DECL_NODE
(
eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd2
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd2_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_out
);
PATTERN_DECL_NODE
(
eltadd1_out
);
PATTERN_DECL_NODE
(
eltadd2_out
);
PATTERN_DECL_NODE
(
reshape2_0
);
PATTERN_DECL_NODE
(
reshape2_1
);
PATTERN_DECL_NODE
(
reshape2_2
);
PATTERN_DECL_NODE
(
reshape2_0_out
);
PATTERN_DECL_NODE
(
reshape2_1_out
);
PATTERN_DECL_NODE
(
reshape2_2_out
);
PATTERN_DECL_NODE
(
transpose2_0
);
PATTERN_DECL_NODE
(
transpose2_1
);
PATTERN_DECL_NODE
(
transpose2_2
);
PATTERN_DECL_NODE
(
transpose2_0_out
);
PATTERN_DECL_NODE
(
transpose2_1_out
);
PATTERN_DECL_NODE
(
transpose2_2_out
);
PATTERN_DECL_NODE
(
scale_q
);
PATTERN_DECL_NODE
(
scale_q_out
);
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
PATTERN_DECL_NODE
(
softmax_qk
);
PATTERN_DECL_NODE
(
softmax_qk_out
);
// QK, V matmul
PATTERN_DECL_NODE
(
matmul_qkv
);
PATTERN_DECL_NODE
(
matmul_qkv_out
);
PATTERN_DECL_NODE
(
reshape2_qkv
);
PATTERN_DECL_NODE
(
reshape2_qkv_out
);
PATTERN_DECL_NODE
(
transpose2_qkv
);
PATTERN_DECL_NODE
(
transpose2_qkv_out
);
// out linear
PATTERN_DECL_NODE
(
matmul_linear
);
PATTERN_DECL_NODE
(
matmul_linear_w
);
PATTERN_DECL_NODE
(
matmul_linear_out
);
PATTERN_DECL_NODE
(
c_allreduce_sum
);
PATTERN_DECL_NODE
(
c_allreduce_sum_out
);
PATTERN_DECL_NODE
(
eltadd_linear
);
PATTERN_DECL_NODE
(
eltadd_linear_b
);
PATTERN_DECL_NODE
(
eltadd_linear_out
);
PATTERN_DECL_NODE
(
dropout_linear
);
PATTERN_DECL_NODE
(
dropout_linear_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
eltadd_out
)
PATTERN_DECL_NODE
(
attention_output
);
// post layer_norm
PATTERN_DECL_NODE
(
layer_norm
);
PATTERN_DECL_NODE
(
layer_norm_scale
);
PATTERN_DECL_NODE
(
layer_norm_bias
);
PATTERN_DECL_NODE
(
layer_norm_mean
);
PATTERN_DECL_NODE
(
layer_norm_variance
);
PATTERN_DECL_NODE
(
layer_norm_out
);
// Feed Forward nodes
PATTERN_DECL_NODE
(
ffn_c_identity
);
PATTERN_DECL_NODE
(
ffn_c_identity_out
);
PATTERN_DECL_NODE
(
ffn_matmul0
);
PATTERN_DECL_NODE
(
ffn_matmul0_w
);
PATTERN_DECL_NODE
(
ffn_matmul0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_act
);
PATTERN_DECL_NODE
(
ffn_act_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_c_allreduce_sum
);
PATTERN_DECL_NODE
(
ffn_c_allreduce_sum_out
);
PATTERN_DECL_NODE
(
ffn_eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
ffn_eltadd_out
)
PATTERN_DECL_NODE
(
ffn_output
);
PATTERN_DECL_NODE
(
ffn_layer_norm
);
PATTERN_DECL_NODE
(
ffn_layer_norm_scale
);
PATTERN_DECL_NODE
(
ffn_layer_norm_bias
);
PATTERN_DECL_NODE
(
ffn_layer_norm_mean
);
PATTERN_DECL_NODE
(
ffn_layer_norm_variance
);
PATTERN_DECL_NODE
(
ffn_layer_norm_out
);
};
struct
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
struct
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
:
public
PatternBase
{
:
public
PatternBase
{
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
(
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
(
...
@@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
...
@@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_
gelu
);
PATTERN_DECL_NODE
(
ffn_
act
);
PATTERN_DECL_NODE
(
ffn_
gelu
_out
);
PATTERN_DECL_NODE
(
ffn_
act
_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
...
@@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
...
@@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
Scope
*
scope
)
const
;
Scope
*
scope
)
const
;
};
};
class
MultiDevicesFusedMultiTransformerEncoderPass
:
public
FusePassBase
{
public:
MultiDevicesFusedMultiTransformerEncoderPass
();
virtual
~
MultiDevicesFusedMultiTransformerEncoderPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"multi_devices_fused_multi_transformer_encoder"
};
private:
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
;
};
class
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
class
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
:
public
FusePassBase
{
:
public
FusePassBase
{
public:
public:
...
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc
浏览文件 @
29eec2dd
...
@@ -56,8 +56,8 @@ Scope* CreateParamScope() {
...
@@ -56,8 +56,8 @@ Scope* CreateParamScope() {
// FFN: fc1 -> (gelu) -> fc2
// FFN: fc1 -> (gelu) -> fc2
AddVarToScope
(
param_scope
,
"ffn_weights0"
,
{
1024
,
4096
});
AddVarToScope
(
param_scope
,
"ffn_weights0"
,
{
1024
,
4096
});
AddVarToScope
(
param_scope
,
"ffn_weights1"
,
{
4096
,
1024
});
AddVarToScope
(
param_scope
,
"ffn_weights1"
,
{
4096
,
1024
});
AddVarToScope
(
param_scope
,
"ffn_bias
_
0"
,
{
4096
});
AddVarToScope
(
param_scope
,
"ffn_bias0"
,
{
4096
});
AddVarToScope
(
param_scope
,
"ffn_bias
_
1"
,
{
1024
});
AddVarToScope
(
param_scope
,
"ffn_bias1"
,
{
1024
});
return
param_scope
;
return
param_scope
;
}
}
...
@@ -65,10 +65,9 @@ Scope* CreateParamScope() {
...
@@ -65,10 +65,9 @@ Scope* CreateParamScope() {
TEST
(
FusedMultiTransformerEncoderPass
,
basic
)
{
TEST
(
FusedMultiTransformerEncoderPass
,
basic
)
{
// inputs operator output
// inputs operator output
// --------------------------------------------------------------------
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (x, weights_0) matmul_v2 -> matmul_out0
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (x, weights_1) matmul_v2 -> matmul_out1
// (layer_norm_out, weights_1) matmul_v2 -> matmul_out1
// (x, weights_2) matmul_v2 -> matmul_out2
// (layer_norm_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
...
@@ -78,7 +77,8 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
...
@@ -78,7 +77,8 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (reshape_0) transpose2 -> transpose_0
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (reshape_2) transpose2 -> transpose_2
// (transpose_0, transpose_1) matmul -> matmul_qk
// (transpose_0) scale -> scale_0
// (scale_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
...
@@ -86,35 +86,28 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
...
@@ -86,35 +86,28 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (transpose_qkv) reshape -> reshape_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_
out)
elementwise_add -> attention_out
// (eltadd_
linear)
elementwise_add -> attention_out
//
//
// (attention_out, scale, bias) layer_norm ->
ffn_
layer_norm_out
// (attention_out, scale, bias) layer_norm -> layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (layer_norm_out, ffn_eltadd1) elementwise_add -> ffn_output
//
// (ffn_output, scale, bias) layer_norm -> ffn_layer_norm_out
// (transpose_1, transpose_2) while -> decoder block
Layers
layers
;
Layers
layers
;
// MHA: pre LayerNorm
// MHA: pre LayerNorm
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
1024
});
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
1024
});
auto
*
ln_scale
=
layers
.
data
(
"ln_scale"
,
{
1024
},
true
);
auto
*
ln_bias
=
layers
.
data
(
"ln_bias"
,
{
1024
},
true
);
auto
*
ln_out
=
layers
.
layer_norm
(
x
,
ln_scale
,
ln_bias
)[
0
];
// MHA: QKV fc
// MHA: QKV fc
auto
*
weights_0
=
layers
.
data
(
"weights0"
,
{
1024
,
1024
},
true
);
auto
*
weights_0
=
layers
.
data
(
"weights0"
,
{
1024
,
1024
},
true
);
auto
*
weights_1
=
layers
.
data
(
"weights1"
,
{
1024
,
1024
},
true
);
auto
*
weights_1
=
layers
.
data
(
"weights1"
,
{
1024
,
1024
},
true
);
auto
*
weights_2
=
layers
.
data
(
"weights2"
,
{
1024
,
1024
},
true
);
auto
*
weights_2
=
layers
.
data
(
"weights2"
,
{
1024
,
1024
},
true
);
auto
*
matmul_out_0
=
auto
*
matmul_out_0
=
layers
.
matmul_v2
(
x
,
weights_0
,
nullptr
,
false
,
false
);
layers
.
matmul_v2
(
ln_out
,
weights_0
,
nullptr
,
false
,
true
);
auto
*
matmul_out_1
=
layers
.
matmul_v2
(
x
,
weights_1
,
nullptr
,
false
,
false
);
auto
*
matmul_out_1
=
auto
*
matmul_out_2
=
layers
.
matmul_v2
(
x
,
weights_2
,
nullptr
,
false
,
false
);
layers
.
matmul_v2
(
ln_out
,
weights_1
,
nullptr
,
false
,
true
);
auto
*
matmul_out_2
=
layers
.
matmul_v2
(
ln_out
,
weights_2
,
nullptr
,
false
,
true
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
1024
},
true
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
1024
},
true
);
auto
*
b1
=
layers
.
data
(
"bias_1"
,
{
1024
},
true
);
auto
*
b1
=
layers
.
data
(
"bias_1"
,
{
1024
},
true
);
...
@@ -136,14 +129,13 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
...
@@ -136,14 +129,13 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto
*
transpose_1
=
layers
.
transpose2
(
reshape_1
,
axis
,
true
);
auto
*
transpose_1
=
layers
.
transpose2
(
reshape_1
,
axis
,
true
);
auto
*
transpose_2
=
layers
.
transpose2
(
reshape_2
,
axis
,
true
);
auto
*
transpose_2
=
layers
.
transpose2
(
reshape_2
,
axis
,
true
);
// Link to decoder while block
// q scale
layers
.
while_loop
({
transpose_1
,
transpose_2
});
auto
*
scale_q
=
layers
.
scale
(
transpose_0
,
0.125
,
0
,
false
);
// MHA: QK matmul
// MHA: QK matmul
auto
*
matmul_qk
=
auto
*
matmul_qk
=
layers
.
matmul
(
transpose_0
,
transpose_1
,
nullptr
,
false
,
true
);
layers
.
matmul
_v2
(
scale_q
,
transpose_1
,
nullptr
,
false
,
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
1
2
,
128
,
128
},
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
1
,
1
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
,
nullptr
,
-
1
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
,
nullptr
,
-
1
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
...
@@ -155,19 +147,18 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
...
@@ -155,19 +147,18 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// MHA: out Linear
// MHA: out Linear
auto
*
weights_l
=
layers
.
data
(
"weights_l"
,
{
1024
,
1024
},
true
);
auto
*
weights_l
=
layers
.
data
(
"weights_l"
,
{
1024
,
1024
},
true
);
auto
*
bias_l
=
layers
.
data
(
"
weightsl"
,
{
1024
,
1024
},
true
);
auto
*
bias_l
=
layers
.
data
(
"
bias_l"
,
{
1024
},
true
);
auto
*
linear_matmut_out
=
auto
*
linear_matmut_out
=
layers
.
matmul_v2
(
reshape_qkv_out
,
weights_l
,
nullptr
,
false
,
tru
e
);
layers
.
matmul_v2
(
reshape_qkv_out
,
weights_l
,
nullptr
,
false
,
fals
e
);
auto
*
linear_eltadd_out
=
auto
*
linear_eltadd_out
=
layers
.
elementwise_add
(
linear_matmut_out
,
bias_l
,
nullptr
,
2
);
layers
.
elementwise_add
(
linear_matmut_out
,
bias_l
,
nullptr
,
2
);
auto
*
attention_out
=
layers
.
elementwise_add
(
x
,
linear_eltadd_out
);
auto
*
attention_out
=
layers
.
elementwise_add
(
x
,
linear_eltadd_out
);
// FFN: pre LayerNorm
// post LayerNorm
auto
*
ffn_ln_scale
=
layers
.
data
(
"ffn_ln_scale"
,
{
1024
},
true
);
auto
*
ln_scale
=
layers
.
data
(
"ln_scale"
,
{
1024
},
true
);
auto
*
ffn_ln_bias
=
layers
.
data
(
"ffn_ln_bias"
,
{
1024
},
true
);
auto
*
ln_bias
=
layers
.
data
(
"ln_bias"
,
{
1024
},
true
);
auto
*
ffn_ln_out
=
auto
*
ln_out
=
layers
.
layer_norm
(
attention_out
,
ln_scale
,
ln_bias
)[
0
];
layers
.
layer_norm
(
attention_out
,
ffn_ln_scale
,
ffn_ln_bias
)[
0
];
// FFN: fc1 -> gelu -> fc2
// FFN: fc1 -> gelu -> fc2
auto
*
ffn_weights0
=
layers
.
data
(
"ffn_weights0"
,
{
1024
,
4096
},
true
);
auto
*
ffn_weights0
=
layers
.
data
(
"ffn_weights0"
,
{
1024
,
4096
},
true
);
...
@@ -175,7 +166,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
...
@@ -175,7 +166,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto
*
ffn_bias0
=
layers
.
data
(
"ffn_bias0"
,
{
4096
},
true
);
auto
*
ffn_bias0
=
layers
.
data
(
"ffn_bias0"
,
{
4096
},
true
);
auto
*
ffn_bias1
=
layers
.
data
(
"ffn_bias1"
,
{
1024
},
true
);
auto
*
ffn_bias1
=
layers
.
data
(
"ffn_bias1"
,
{
1024
},
true
);
auto
*
ffn_matmul0_out
=
auto
*
ffn_matmul0_out
=
layers
.
matmul_v2
(
ffn_
ln_out
,
ffn_weights0
,
nullptr
,
false
,
true
);
layers
.
matmul_v2
(
ln_out
,
ffn_weights0
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd0_out
=
auto
*
ffn_eltadd0_out
=
layers
.
elementwise_add
(
ffn_matmul0_out
,
ffn_bias0
,
nullptr
,
2
);
layers
.
elementwise_add
(
ffn_matmul0_out
,
ffn_bias0
,
nullptr
,
2
);
auto
*
ffn_gelu_out
=
layers
.
gelu
(
ffn_eltadd0_out
);
auto
*
ffn_gelu_out
=
layers
.
gelu
(
ffn_eltadd0_out
);
...
@@ -184,7 +175,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
...
@@ -184,7 +175,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto
*
ffn_eltadd1_out
=
auto
*
ffn_eltadd1_out
=
layers
.
elementwise_add
(
ffn_matmul1_out
,
ffn_bias1
,
nullptr
,
2
);
layers
.
elementwise_add
(
ffn_matmul1_out
,
ffn_bias1
,
nullptr
,
2
);
layers
.
elementwise_add
(
attention_out
,
ffn_eltadd1_out
);
auto
*
ffn_out
=
layers
.
elementwise_add
(
ln_out
,
ffn_eltadd1_out
);
// FFN: post LayerNorm
auto
*
ffn_ln_scale
=
layers
.
data
(
"ffn_ln_scale"
,
{
1024
},
true
);
auto
*
ffn_ln_bias
=
layers
.
data
(
"ffn_ln_bias"
,
{
1024
},
true
);
layers
.
layer_norm
(
ffn_out
,
ffn_ln_scale
,
ffn_ln_bias
)[
0
];
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
...
@@ -203,12 +199,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
...
@@ -203,12 +199,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer"
);
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer"
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
5
6
,
num_nodes_after
+
5
8
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder_pass, The "
"After the fused_multi_transformer_encoder_pass, The "
"node num in graph "
"node num in graph "
"should be %d, but the result is %d"
,
"should be %d, but the result is %d"
,
num_nodes_before
-
5
6
,
num_nodes_before
-
5
8
,
num_nodes_after
));
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
1
,
...
@@ -225,6 +221,183 @@ TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) {
...
@@ -225,6 +221,183 @@ TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) {
.
IsPassCompatible
(
"fused_multi_transformer_encoder_pass"
));
.
IsPassCompatible
(
"fused_multi_transformer_encoder_pass"
));
}
}
TEST
(
MultiDevicesFusedMultiTransformerEncoderPass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------------------
// (x) c_identity -> c_identity0_out
// (x) c_identity -> c_identity1_out
// (x) c_identity -> c_identity2_out
// (c_identity0_out, weights_0) matmul_v2 -> matmul_out0
// (c_identity1_out, weights_1) matmul_v2 -> matmul_out1
// (c_identity2_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_0) scale -> scale_0
// (scale_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_all_reduce -> c_all_reduce_out
// (c_all_reduce_out) elementwise_add -> eltadd_linear
// (eltadd_linear) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> layer_norm_out
// (layer_norm_out) c_identity -> ffn_c_identity_out
// (ffn_c_identity_out, ffn_matmul0_w)matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out
// (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1
// (layer_norm_out, ffn_eltadd1) elementwise_add -> ffn_output
// (ffn_output, scale, bias) layer_norm -> ffn_layer_norm_out
Layers
layers
;
// MHA: pre LayerNorm
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
1024
});
auto
*
c_identity0_out
=
layers
.
c_identity
(
x
);
auto
*
c_identity1_out
=
layers
.
c_identity
(
x
);
auto
*
c_identity2_out
=
layers
.
c_identity
(
x
);
// MHA: QKV fc
auto
*
weights_0
=
layers
.
data
(
"weights0"
,
{
1024
,
1024
},
true
);
auto
*
weights_1
=
layers
.
data
(
"weights1"
,
{
1024
,
1024
},
true
);
auto
*
weights_2
=
layers
.
data
(
"weights2"
,
{
1024
,
1024
},
true
);
auto
*
matmul_out_0
=
layers
.
matmul_v2
(
c_identity0_out
,
weights_0
,
nullptr
,
false
,
false
);
auto
*
matmul_out_1
=
layers
.
matmul_v2
(
c_identity1_out
,
weights_1
,
nullptr
,
false
,
false
);
auto
*
matmul_out_2
=
layers
.
matmul_v2
(
c_identity2_out
,
weights_2
,
nullptr
,
false
,
false
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
1024
},
true
);
auto
*
b1
=
layers
.
data
(
"bias_1"
,
{
1024
},
true
);
auto
*
b2
=
layers
.
data
(
"bias_2"
,
{
1024
},
true
);
auto
*
elementwise_out_0
=
layers
.
elementwise_add
(
matmul_out_0
,
b0
,
nullptr
,
2
);
auto
*
elementwise_out_1
=
layers
.
elementwise_add
(
matmul_out_1
,
b1
,
nullptr
,
2
);
auto
*
elementwise_out_2
=
layers
.
elementwise_add
(
matmul_out_2
,
b2
,
nullptr
,
2
);
std
::
vector
<
int
>
shape
=
{
1
,
128
,
16
,
64
};
auto
*
reshape_0
=
layers
.
reshape2
(
elementwise_out_0
,
shape
,
true
);
auto
*
reshape_1
=
layers
.
reshape2
(
elementwise_out_1
,
shape
,
true
);
auto
*
reshape_2
=
layers
.
reshape2
(
elementwise_out_2
,
shape
,
true
);
std
::
vector
<
int
>
axis
=
{
0
,
2
,
1
,
3
};
auto
*
transpose_0
=
layers
.
transpose2
(
reshape_0
,
axis
,
true
);
auto
*
transpose_1
=
layers
.
transpose2
(
reshape_1
,
axis
,
true
);
auto
*
transpose_2
=
layers
.
transpose2
(
reshape_2
,
axis
,
true
);
// q scale
auto
*
scale_q
=
layers
.
scale
(
transpose_0
,
0.125
,
0
,
false
);
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul_v2
(
scale_q
,
transpose_1
,
nullptr
,
false
,
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
1
,
1
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
,
nullptr
,
-
1
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
// MHA: QKV matmul
auto
*
matmul_qkv
=
layers
.
matmul_v2
(
softmax_qk
,
transpose_2
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
1024
},
true
);
// MHA: out Linear
auto
*
weights_l
=
layers
.
data
(
"weights_l"
,
{
1024
,
1024
},
true
);
auto
*
bias_l
=
layers
.
data
(
"bias_l"
,
{
1024
},
true
);
auto
*
linear_matmut_out
=
layers
.
matmul_v2
(
reshape_qkv_out
,
weights_l
,
nullptr
,
false
,
false
);
auto
*
c_allreduce_out
=
layers
.
c_allreduce_sum
(
linear_matmut_out
);
auto
*
linear_eltadd_out
=
layers
.
elementwise_add
(
c_allreduce_out
,
bias_l
,
nullptr
,
2
);
auto
*
attention_out
=
layers
.
elementwise_add
(
x
,
linear_eltadd_out
);
// post LayerNorm
auto
*
ln_scale
=
layers
.
data
(
"ln_scale"
,
{
1024
},
true
);
auto
*
ln_bias
=
layers
.
data
(
"ln_bias"
,
{
1024
},
true
);
auto
*
ln_out
=
layers
.
layer_norm
(
attention_out
,
ln_scale
,
ln_bias
)[
0
];
auto
*
ffn_c_identity_out
=
layers
.
c_identity
(
ln_out
);
// FFN: fc1 -> gelu -> fc2
auto
*
ffn_weights0
=
layers
.
data
(
"ffn_weights0"
,
{
1024
,
4096
},
true
);
auto
*
ffn_weights1
=
layers
.
data
(
"ffn_weights1"
,
{
4096
,
1024
},
true
);
auto
*
ffn_bias0
=
layers
.
data
(
"ffn_bias0"
,
{
4096
},
true
);
auto
*
ffn_bias1
=
layers
.
data
(
"ffn_bias1"
,
{
1024
},
true
);
auto
*
ffn_matmul0_out
=
layers
.
matmul_v2
(
ffn_c_identity_out
,
ffn_weights0
,
nullptr
,
false
,
false
);
auto
*
ffn_eltadd0_out
=
layers
.
elementwise_add
(
ffn_matmul0_out
,
ffn_bias0
,
nullptr
,
2
);
auto
*
ffn_gelu_out
=
layers
.
gelu
(
ffn_eltadd0_out
);
auto
*
ffn_matmul1_out
=
layers
.
matmul_v2
(
ffn_gelu_out
,
ffn_weights1
,
nullptr
,
false
,
false
);
auto
*
ffn_allreduce_out
=
layers
.
c_allreduce_sum
(
ffn_matmul1_out
);
auto
*
ffn_eltadd1_out
=
layers
.
elementwise_add
(
ffn_allreduce_out
,
ffn_bias1
,
nullptr
,
2
);
auto
*
ffn_out
=
layers
.
elementwise_add
(
ln_out
,
ffn_eltadd1_out
);
// FFN: post LayerNorm
auto
*
ffn_ln_scale
=
layers
.
data
(
"ffn_ln_scale"
,
{
1024
},
true
);
auto
*
ffn_ln_bias
=
layers
.
data
(
"ffn_ln_bias"
,
{
1024
},
true
);
layers
.
layer_norm
(
ffn_out
,
ffn_ln_scale
,
ffn_ln_bias
)[
0
];
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
graph
->
Set
(
"enable_int8"
,
new
bool
(
false
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"multi_devices_fused_multi_transformer_encoder_pass"
);
if
(
pass
.
get
()
==
nullptr
)
LOG
(
INFO
)
<<
"get multi_devices_fused_multi_transformer_encoder_pass failed"
;
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer"
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
70
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder_pass, The "
"node num in graph "
"should be %d, but the result is %d"
,
num_nodes_before
-
70
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder pass, "
"there should be one fused_multi_transformer op, "
"but the result is %d"
,
num_fused_nodes_after
));
}
TEST
(
MultiDevicesFusedMultiTransformerEncoderPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"multi_devices_fused_multi_transformer_encoder_pass"
));
}
TEST
(
FusedMultiTransformerEncoderFuseQKVPass
,
basic
)
{
TEST
(
FusedMultiTransformerEncoderFuseQKVPass
,
basic
)
{
// inputs operator output
// inputs operator output
// --------------------------------------------------------------------
// --------------------------------------------------------------------
...
@@ -292,7 +465,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
...
@@ -292,7 +465,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
auto
*
matmul_qk
=
layers
.
matmul_v2
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
matmul_qk
=
layers
.
matmul_v2
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
scale_qk
=
layers
.
scale
(
matmul_qk
,
0.125
,
0
,
false
);
auto
*
scale_qk
=
layers
.
scale
(
matmul_qk
,
0.125
,
0
,
false
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
1
2
,
128
,
128
},
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
1
,
1
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
scale_qk
,
bqk
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
scale_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
...
@@ -447,7 +620,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
...
@@ -447,7 +620,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
auto
*
matmul_qk
=
layers
.
matmul_v2
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
matmul_qk
=
layers
.
matmul_v2
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
scale_qk
=
layers
.
scale
(
matmul_qk
,
0.125
,
0
,
false
);
auto
*
scale_qk
=
layers
.
scale
(
matmul_qk
,
0.125
,
0
,
false
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
1
2
,
128
,
128
},
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
1
,
1
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
scale_qk
,
bqk
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
scale_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
...
@@ -542,4 +715,5 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
...
@@ -542,4 +715,5 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
USE_PASS
(
fused_multi_transformer_encoder_pass
);
USE_PASS
(
fused_multi_transformer_encoder_pass
);
USE_PASS
(
fused_multi_transformer_encoder_fuse_qkv_pass
);
USE_PASS
(
fused_multi_transformer_encoder_fuse_qkv_pass
);
USE_PASS
(
multi_devices_fused_multi_transformer_encoder_pass
);
USE_PASS
(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass
);
USE_PASS
(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass
);
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
29eec2dd
...
@@ -179,6 +179,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
...
@@ -179,6 +179,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fused_multi_transformer_decoder_pass"
,
"fused_multi_transformer_decoder_pass"
,
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_encoder_pass"
,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
"fuse_multi_transformer_layer_pass"
,
"fuse_multi_transformer_layer_pass"
,
...
@@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
...
@@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"fused_multi_transformer_decoder_pass"
,
//
"fused_multi_transformer_decoder_pass"
,
//
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"multi_devices_fused_multi_transformer_encoder_pass"
,
//
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"fuse_multi_transformer_layer_pass"
,
//
"fuse_multi_transformer_layer_pass"
,
//
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录