PaddlePaddle / Paddle

Commit 29eec2dd (unverified)
Authored Jan 04, 2023 by lzy; committed by GitHub on Jan 04, 2023
add multi_devices_fused_multi_transformer_encoder_pass and cherry-pick from 48349 (#49383)
Parent: a2d7e1d7

Showing 6 changed files with 2776 additions and 1241 deletions:

  paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc          +50    -36
  paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h            +6     -6
  paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc        +2341  -1140
  paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h          +164    -20
  paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc  +213    -39
  paddle/fluid/inference/api/paddle_pass_builder.cc                           +2     -0
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc

@@ -31,6 +31,8 @@ namespace framework {
 namespace ir {
 namespace patterns {
 
+static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
+
 PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* input0 = pattern->NewNode(input0_repr());
   input0->assert_is_op_input("layer_norm", "X");
@@ -359,11 +361,11 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
+                                  ->assert_is_ops_input(FFN_ACTS);
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2");
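A note on the hunk above: the pass previously matched a literal "gelu" op and now accepts any activation listed in FFN_ACTS, so relu-based feed-forward blocks fuse too. The mechanism is a plain set-membership test over the op type string; a minimal standalone sketch (plain C++, not Paddle code; IsAcceptedFfnAct is a hypothetical helper name):

#include <iostream>
#include <string>
#include <unordered_set>

// The same set the diff introduces: activations the FFN pattern will accept.
static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};

// Hypothetical helper: true if an op type should match the ffn_act node.
bool IsAcceptedFfnAct(const std::string& op_type) {
  return FFN_ACTS.count(op_type) > 0;
}

int main() {
  for (const std::string op_type : {"relu", "gelu", "sigmoid"}) {
    std::cout << op_type << ": "
              << (IsAcceptedFfnAct(op_type) ? "matched" : "not matched")
              << "\n";
  }
  return 0;
}

The assert_is_ops / assert_is_ops_input / assert_is_ops_output calls in the diff are the PDPattern counterparts of this check, taking the whole set instead of a single op name.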
@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
@@ -678,11 +680,11 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
+                                  ->assert_is_ops_input(FFN_ACTS);
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2");
@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
@@ -1026,11 +1028,11 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
+                                  ->assert_is_ops_input(FFN_ACTS);
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2");
@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
       .LinksTo({ffn_c_allreduce_sum_out_var});
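A note on the multi-devices hunk above: ffn_c_allreduce_sum sits between ffn_matmul1 and the bias add because the FFN is tensor-parallel. Each rank multiplies its local slice of the hidden activations by its row slice of the second FFN weight, so per-rank outputs are partial sums, and an element-wise all-reduce restores the full result. A single-process sketch of that arithmetic (plain C++; the two-rank split and toy sizes are illustrative assumptions):

#include <array>
#include <iostream>

int main() {
  // hidden (size 4) x W2 (4 x 2), with W2's rows split across two "ranks".
  std::array<double, 4> hidden{1.0, 2.0, 3.0, 4.0};
  double w2[4][2] = {{1, 0}, {0, 1}, {1, 1}, {2, -1}};

  std::array<double, 2> out{0.0, 0.0};
  for (int rank = 0; rank < 2; ++rank) {
    std::array<double, 2> partial{0.0, 0.0};  // this rank's local matmul
    for (int r = rank * 2; r < rank * 2 + 2; ++r) {
      for (int c = 0; c < 2; ++c) partial[c] += hidden[r] * w2[r][c];
    }
    for (int c = 0; c < 2; ++c) out[c] += partial[c];  // the allreduce-sum step
  }
  std::cout << out[0] << " " << out[1] << "\n";  // 12 1, same as the full matmul
}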
@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                        Node* ffn_matmul1_w,
                        Node* ffn_eltadd0_b,
                        Node* ffn_eltadd1_b,
+                       Node* ffn_act,
                        Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("is_test", true);
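The added SetAttr lines record which activation the pattern actually matched, so the fused_multi_transformer op can replay it at run time instead of assuming gelu. A minimal stand-in for that hand-off (plain C++; FakeOpDesc is a hypothetical substitute for framework::OpDesc):

#include <iostream>
#include <map>
#include <string>

struct FakeOpDesc {  // toy stand-in for framework::OpDesc (assumption)
  std::map<std::string, std::string> string_attrs;
  void SetAttr(const std::string& name, const std::string& value) {
    string_attrs[name] = value;
  }
};

int main() {
  // In the pass this string comes from ffn_act->Op()->Type().
  std::string matched_act_type = "relu";
  FakeOpDesc fused_op_desc;
  fused_op_desc.SetAttr("act_method", matched_act_type);
  std::cout << "act_method = " << fused_op_desc.string_attrs["act_method"]
            << "\n";
}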
@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.
@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                        Node* ffn_matmul1_w,
                        Node* ffn_eltadd0_b,
                        Node* ffn_eltadd1_b,
+                       Node* ffn_act,
                        Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
         fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.
@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                        Node* ffn_matmul1_w,
                        Node* ffn_eltadd0_b,
                        Node* ffn_eltadd1_b,
+                       Node* ffn_act,
                        Node* ffn_output) {
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_1_op = ffn_matmul1->Op();
@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
         fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.
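Each BuildFusion variant in this file ends the same way: the matched nodes are collected into a marked_nodes set and deleted in one sweep after the fused op is wired in. A toy sketch of that collect-then-remove idiom (plain C++; Paddle's passes hand the set to a graph-removal helper, this sketch only mimics the set-based filtering):

#include <iostream>
#include <list>
#include <string>
#include <unordered_set>

int main() {
  // Op names standing in for graph nodes after the fused op was inserted.
  std::list<std::string> graph{"layer_norm", "ffn_act", "ffn_act_out",
                               "fused_multi_transformer", "fetch"};
  // Everything the pattern matched gets marked first...
  std::unordered_set<std::string> marked_nodes{"layer_norm", "ffn_act",
                                               "ffn_act_out"};
  // ...then removed in a single pass, so no half-rewired node is revisited.
  graph.remove_if(
      [&](const std::string& n) { return marked_nodes.count(n) > 0; });
  for (const auto& n : graph) std::cout << n << "\n";
}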
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h

@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);

@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);

@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
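The header changes above are a mechanical rename of the declared pattern nodes. PATTERN_DECL_NODE is Paddle's helper for stamping out a named node handle on a pattern struct; an illustrative macro in the same spirit (simplified, not Paddle's actual definition):

#include <iostream>
#include <string>

// Toy macro: each invocation generates a uniquely named accessor, so renaming
// ffn_gelu -> ffn_act is a rename of the pattern's node handles.
#define TOY_PATTERN_DECL_NODE(name__) \
  std::string name__##_repr() const { return prefix_ + "/" + #name__; }

struct ToyPattern {
  std::string prefix_ = "fused_multi_transformer_decoder";
  TOY_PATTERN_DECL_NODE(ffn_act);
  TOY_PATTERN_DECL_NODE(ffn_act_out);
};

int main() {
  ToyPattern p;
  std::cout << p.ffn_act_repr() << "\n" << p.ffn_act_out_repr() << "\n";
}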
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc

@@ -25,44 +25,20 @@ namespace framework {
 namespace ir {
 namespace patterns {
 
-PDNode* FusedMultiTransformerEncoderPattern::operator()() {
-  auto* input0 = pattern->NewNode(input0_repr());
-  input0->assert_is_op_input("layer_norm", "X");
-
-  // pre-LayerNorm
-  auto* layer_norm =
-      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
-  auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
-                                   ->AsInput()
-                                   ->assert_is_persistable_var()
-                                   ->assert_is_op_input("layer_norm", "Scale");
-  auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
-                                  ->AsInput()
-                                  ->assert_is_persistable_var()
-                                  ->assert_is_op_input("layer_norm", "Bias");
-  auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
-                                  ->AsOutput()
-                                  ->assert_is_op_output("layer_norm", "Mean");
-  auto* layer_norm_variance_var =
-      pattern->NewNode(layer_norm_variance_repr())
-          ->AsOutput()
-          ->assert_is_op_output("layer_norm", "Variance");
-  auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
-                                 ->AsIntermediate()
-                                 ->assert_is_op_output("layer_norm", "Y")
-                                 ->assert_is_op_input("matmul_v2", "X")
-                                 ->assert_more([](Node* x) {
-                                   if (x->outputs.size() == 3) {
-                                     return true;
-                                   } else {
-                                     return false;
-                                   }
-                                 });
-  layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
-      .LinksTo(
-          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
+static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
+
+PDNode* FusedMultiTransformerEncoderPattern::operator()() {
+  auto* input0 = pattern->NewNode(input0_repr())
+                     ->assert_is_op_input("matmul_v2", "X")
+                     ->assert_is_op_input("elementwise_add", "X")
+                     ->assert_more([](Node* x) {
+                       if (x->outputs.size() == 4) {
+                         return true;
+                       } else {
+                         return false;
+                       }
+                     });
 
   // Q path Nodes
   auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
   auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
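The rewritten encoder pattern above anchors on the raw block input rather than a pre-LayerNorm output, and the assert_more lambda pins that input's fan-out to exactly four consumers, presumably the Q, K and V matmul_v2 ops plus the residual elementwise_add. A toy version of such a structural predicate (plain C++; Node here is a stand-in, not Paddle's ir::Node):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Node {  // toy stand-in: outputs holds consumer op types
  std::vector<std::string> outputs;
};

int main() {
  std::function<bool(Node*)> teller = [](Node* x) {
    return x->outputs.size() == 4;
  };
  Node input0{{"matmul_v2", "matmul_v2", "matmul_v2", "elementwise_add"}};
  std::cout << (teller(&input0) ? "anchor accepted" : "anchor rejected")
            << "\n";
}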
@@ -95,15 +71,20 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
   auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
                                    ->assert_is_op_output("transpose2")
                                    ->AsIntermediate()
-                                   ->assert_is_op_input("matmul", "X");
+                                   ->assert_is_op_input("scale");
+  auto* scale_q = pattern->NewNode(scale_q_repr())->assert_is_op("scale");
+  auto* scale_q_out_var = pattern->NewNode(scale_q_out_repr())
+                              ->assert_is_op_output("scale")
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2", "X");
 
   // Q path Links
-  matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
-      .LinksTo({matmul0_out_var});
+  matmul0->LinksFrom({input0, matmul0_w_var}).LinksTo({matmul0_out_var});
   eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
       .LinksTo({eltadd0_out_var});
   reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
   transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
+  scale_q->LinksFrom({transpose2_0_out_var}).LinksTo({scale_q_out_var});
 
   // K path Nodes
   auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2");

@@ -137,20 +118,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
       pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
   auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
                                    ->assert_is_op_output("transpose2")
-                                   ->AsOutput()
-                                   ->assert_is_op_input("matmul", "Y")
-                                   ->assert_is_op_input("while")
-                                   ->assert_more([](Node* x) {
-                                     if (x->outputs.size() == 2) {
-                                       return true;
-                                     } else {
-                                       return false;
-                                     }
-                                   });
+                                   ->AsIntermediate()
+                                   ->assert_is_op_input("matmul_v2", "Y");
 
   // K path Links
-  matmul1->LinksFrom({layer_norm_out_var, matmul1_w_var})
-      .LinksTo({matmul1_out_var});
+  matmul1->LinksFrom({input0, matmul1_w_var}).LinksTo({matmul1_out_var});
   eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var})
       .LinksTo({eltadd1_out_var});
   reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});

@@ -187,29 +159,21 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
       pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
   auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
                                    ->assert_is_op_output("transpose2")
-                                   ->AsOutput()
-                                   ->assert_is_op_input("matmul_v2", "Y")
-                                   ->assert_is_op_input("while")
-                                   ->assert_more([](Node* x) {
-                                     if (x->outputs.size() == 2) {
-                                       return true;
-                                     } else {
-                                       return false;
-                                     }
-                                   });
+                                   ->AsIntermediate()
+                                   ->assert_is_op_input("matmul_v2", "Y");
 
   // V path Links
-  matmul2->LinksFrom({layer_norm_out_var, matmul2_w_var})
-      .LinksTo({matmul2_out_var});
+  matmul2->LinksFrom({input0, matmul2_w_var}).LinksTo({matmul2_out_var});
   eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var})
       .LinksTo({eltadd2_out_var});
   reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
   transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
 
   // QK path Nodes
-  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk =
+      pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
   auto* matmul_qk_out_var =
-      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
+      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
   matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
 
   auto* eltadd_qk =
@@ -230,7 +194,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
           ->assert_is_op_input("matmul_v2", "X");
 
   // QK path Linsk
-  matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
+  matmul_qk->LinksFrom({scale_q_out_var, transpose2_1_out_var})
       .LinksTo({matmul_qk_out_var});
   eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
       .LinksTo({eltadd_qk_out_var});
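The relink above reflects where the attention scaling now lives in the matched graph: a standalone scale op applied to Q before the QK matmul_v2, instead of scaling folded into the old matmul op. Scaling Q first is numerically equivalent to scaling the product afterwards; a small check of that equivalence (plain C++; the head size of 4 is illustrative):

#include <cmath>
#include <iostream>

int main() {
  const int d = 4;
  double q[d] = {1, 2, 3, 4};
  double k[d] = {4, 3, 2, 1};
  double alpha = 1.0 / std::sqrt(static_cast<double>(d));  // 1/sqrt(dim_head)

  double dot = 0.0, scaled_dot = 0.0;
  for (int i = 0; i < d; ++i) {
    dot += q[i] * k[i];
    scaled_dot += (q[i] * alpha) * k[i];  // scale_q applied before matmul_qk
  }
  std::cout << dot * alpha << " == " << scaled_dot << "\n";  // both print 10
}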
@@ -297,42 +261,41 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
   eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
       .LinksTo({attention_output});
 
-  // while loop
-  auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
-  while0->LinksFrom({transpose2_1_out_var, transpose2_2_out_var});
-
-  // Feed Forward LayerNorm Nodes
-  auto* ffn_layer_norm =
-      pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
-  auto* ffn_layer_norm_scale_var =
-      pattern->NewNode(ffn_layer_norm_scale_repr())
-          ->AsInput()
-          ->assert_is_persistable_var()
-          ->assert_is_op_input("layer_norm", "Scale");
-  auto* ffn_layer_norm_bias_var =
-      pattern->NewNode(ffn_layer_norm_bias_repr())
-          ->AsInput()
-          ->assert_is_persistable_var()
-          ->assert_is_op_input("layer_norm", "Bias");
-  auto* ffn_layer_norm_mean_var =
-      pattern->NewNode(ffn_layer_norm_mean_repr())
-          ->AsOutput()
-          ->assert_is_op_output("layer_norm", "Mean");
-  auto* ffn_layer_norm_variance_var =
-      pattern->NewNode(ffn_layer_norm_variance_repr())
-          ->AsOutput()
-          ->assert_is_op_output("layer_norm", "Variance");
-  auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
-                                     ->AsIntermediate()
-                                     ->assert_is_op_output("layer_norm", "Y")
-                                     ->assert_is_op_input("matmul_v2", "X");
-  ffn_layer_norm
-      ->LinksFrom(
-          {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
-      .LinksTo({ffn_layer_norm_out_var,
-                ffn_layer_norm_mean_var,
-                ffn_layer_norm_variance_var});
+  // post-LayerNorm
+  auto* layer_norm =
+      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
+  auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
+                                   ->AsInput()
+                                   ->assert_is_persistable_var()
+                                   ->assert_is_op_input("layer_norm", "Scale");
+  auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
+                                  ->AsInput()
+                                  ->assert_is_persistable_var()
+                                  ->assert_is_op_input("layer_norm", "Bias");
+  auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
+                                  ->AsOutput()
+                                  ->assert_is_op_output("layer_norm", "Mean");
+  auto* layer_norm_variance_var =
+      pattern->NewNode(layer_norm_variance_repr())
+          ->AsOutput()
+          ->assert_is_op_output("layer_norm", "Variance");
+  auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
+                                 ->AsIntermediate()
+                                 ->assert_is_op_output("layer_norm", "Y")
+                                 ->assert_is_op_input("matmul_v2", "X")
+                                 ->assert_is_op_input("elementwise_add", "X")
+                                 ->assert_more([](Node* x) {
+                                   if (x->outputs.size() == 2) {
+                                     return true;
+                                   } else {
+                                     return false;
+                                   }
+                                 });
+  layer_norm
+      ->LinksFrom({attention_output, layer_norm_bias_var, layer_norm_scale_var})
+      .LinksTo(
+          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
 
   // Feed Forward fc1 -> gelu -> fc2
   auto* ffn_matmul0 =

@@ -353,11 +316,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
+                                  ->assert_is_ops_input(FFN_ACTS);
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2");

@@ -385,22 +348,55 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
       pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
   auto* ffn_output = pattern->NewNode(ffn_output_repr())
                          ->assert_is_op_output("elementwise_add")
-                         ->AsOutput();
+                         ->AsIntermediate()
+                         ->assert_is_op_input("layer_norm");
 
-  ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
+  ffn_matmul0->LinksFrom({layer_norm_out_var, ffn_matmul0_w_var})
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
 
-  ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
+  ffn_eltadd_out->LinksFrom({layer_norm_out_var, ffn_eltadd1_out_var})
       .LinksTo({ffn_output});
 
-  return ffn_output;
+  // Feed Forward LayerNorm Nodes
+  auto* ffn_layer_norm =
+      pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
+  auto* ffn_layer_norm_scale_var =
+      pattern->NewNode(ffn_layer_norm_scale_repr())
+          ->AsInput()
+          ->assert_is_persistable_var()
+          ->assert_is_op_input("layer_norm", "Scale");
+  auto* ffn_layer_norm_bias_var =
+      pattern->NewNode(ffn_layer_norm_bias_repr())
+          ->AsInput()
+          ->assert_is_persistable_var()
+          ->assert_is_op_input("layer_norm", "Bias");
+  auto* ffn_layer_norm_mean_var =
+      pattern->NewNode(ffn_layer_norm_mean_repr())
+          ->AsOutput()
+          ->assert_is_op_output("layer_norm", "Mean");
+  auto* ffn_layer_norm_variance_var =
+      pattern->NewNode(ffn_layer_norm_variance_repr())
+          ->AsOutput()
+          ->assert_is_op_output("layer_norm", "Variance");
+  auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
+                                     ->AsOutput()
+                                     ->assert_is_op_output("layer_norm", "Y");
+
+  ffn_layer_norm
+      ->LinksFrom(
+          {ffn_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
+      .LinksTo({ffn_layer_norm_out_var,
+                ffn_layer_norm_mean_var,
+                ffn_layer_norm_variance_var});
+
+  return ffn_layer_norm_out_var;
 }
 
 PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {

@@ -649,11 +645,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
+                                  ->assert_is_ops_input(FFN_ACTS);
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2");

@@ -687,8 +683,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
@@ -699,47 +695,41 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -699,47 +695,41 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
return
ffn_output
;
return
ffn_output
;
}
}
PDNode
*
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
::
operator
()()
{
PDNode
*
MultiDevicesFusedMultiTransformerEncoderPattern
::
operator
()()
{
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
())
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
->
assert_is_op_input
(
"c_identity"
,
"X"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
// pre-LayerNorm
->
assert_more
([](
Node
*
x
)
{
auto
*
layer_norm
=
if
(
x
->
outputs
.
size
()
==
4
)
{
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
return
true
;
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
}
else
{
->
AsInput
()
return
false
;
->
assert_is_persistable_var
()
}
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
});
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
.
LinksTo
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
// communication c_identity
// communication c_identity
auto
*
c_identity
=
auto
*
c_identity
0
=
pattern
->
NewNode
(
c_identity_repr
())
->
assert_is_op
(
"c_identity"
);
pattern
->
NewNode
(
c_identity
0
_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
c_identity
_out_var
=
pattern
->
NewNode
(
c_identity
_out_repr
())
auto
*
c_identity
0_out_var
=
pattern
->
NewNode
(
c_identity0
_out_repr
())
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
c_identity
->
LinksFrom
({
layer_norm_out_var
}).
LinksTo
({
c_identity_out_var
});
auto
*
c_identity1
=
pattern
->
NewNode
(
c_identity1_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
c_identity1_out_var
=
pattern
->
NewNode
(
c_identity1_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
auto
*
c_identity2
=
pattern
->
NewNode
(
c_identity2_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
c_identity2_out_var
=
pattern
->
NewNode
(
c_identity2_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
c_identity0
->
LinksFrom
({
input0
}).
LinksTo
({
c_identity0_out_var
});
c_identity1
->
LinksFrom
({
input0
}).
LinksTo
({
c_identity1_out_var
});
c_identity2
->
LinksFrom
({
input0
}).
LinksTo
({
c_identity2_out_var
});
// Q
KV fused
path Nodes
// Q path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
->
AsInput
()
->
AsInput
()
...
@@ -771,75 +761,137 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -771,75 +761,137 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"split"
,
"X"
);
->
assert_is_op_input
(
"scale"
);
auto
*
scale_q
=
pattern
->
NewNode
(
scale_q_repr
())
->
assert_is_op
(
"scale"
);
auto
*
split0
=
pattern
->
NewNode
(
split0_repr
())
->
assert_is_op
(
"split"
);
auto
*
scale_q_out_var
=
pattern
->
NewNode
(
scale_q_out_repr
())
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"scale"
)
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
auto
*
split0_v_out_var
=
pattern
->
NewNode
(
split0_v_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
// Q
KV fused
path Links
// Q path Links
matmul0
->
LinksFrom
({
c_identity_out_var
,
matmul0_w_var
})
matmul0
->
LinksFrom
({
c_identity
0
_out_var
,
matmul0_w_var
})
.
LinksTo
({
matmul0_out_var
});
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
split0
->
LinksFrom
({
transpose2_0_out_var
})
scale_q
->
LinksFrom
({
transpose2_0_out_var
}).
LinksTo
({
scale_q_out_var
});
.
LinksTo
({
split0_q_out_var
,
split0_k_out_var
,
split0_v_out_var
});
// while loop
auto
*
while0
=
pattern
->
NewNode
(
while0_repr
())
->
assert_is_op
(
"while"
);
while0
->
LinksFrom
({
split0_k_out_var
,
split0_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"scale"
);
auto
*
scale_qk
=
pattern
->
NewNode
(
scale_qk_repr
())
->
assert_is_op
(
"scale"
);
// K path Nodes
auto
*
scale_qk_out_var
=
pattern
->
NewNode
(
scale_qk_out_repr
())
auto
*
matmul1
=
pattern
->
NewNode
(
matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
->
assert_is_op_output
(
"scale"
)
auto
*
matmul1_w_var
=
pattern
->
NewNode
(
matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul1_out_var
=
pattern
->
NewNode
(
matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X
"
);
->
assert_is_op_input
(
"elementwise_add
"
);
auto
*
eltadd
_qk
=
auto
*
eltadd
1
=
pattern
->
NewNode
(
eltadd
_qk
_repr
())
->
assert_is_op
(
"elementwise_add"
);
pattern
->
NewNode
(
eltadd
1
_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd
_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk
_b_repr
())
auto
*
eltadd
1_b_var
=
pattern
->
NewNode
(
eltadd1
_b_repr
())
->
AsInput
()
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
auto
*
eltadd1_out_var
=
pattern
->
NewNode
(
eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax
"
);
->
assert_is_op_input
(
"reshape2
"
);
auto
*
softmax_qk
=
auto
*
reshape2_1
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax
"
);
pattern
->
NewNode
(
reshape2_1_repr
())
->
assert_is_op
(
"reshape2
"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk
_out_repr
())
auto
*
reshape2_1_out_var
=
pattern
->
NewNode
(
reshape2_1
_out_repr
())
->
assert_is_op_output
(
"
softmax
"
)
->
assert_is_op_output
(
"
reshape2
"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"
matmul_v2"
,
"X
"
);
->
assert_is_op_input
(
"
transpose2
"
);
// QK path Linsk
auto
*
transpose2_1
=
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
split0_k_out_var
})
pattern
->
NewNode
(
transpose2_1_repr
())
->
assert_is_op
(
"transpose2"
);
.
LinksTo
({
matmul_qk_out_var
});
auto
*
transpose2_1_out_var
=
pattern
->
NewNode
(
transpose2_1_out_repr
())
scale_qk
->
LinksFrom
({
matmul_qk_out_var
}).
LinksTo
({
scale_qk_out_var
});
->
assert_is_op_output
(
"transpose2"
)
eltadd_qk
->
LinksFrom
({
scale_qk_out_var
,
eltadd_qk_b_var
})
->
AsIntermediate
()
.
LinksTo
({
eltadd_qk_out_var
});
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
// K path Links
matmul1
->
LinksFrom
({
c_identity1_out_var
,
matmul1_w_var
})
.
LinksTo
({
matmul1_out_var
});
eltadd1
->
LinksFrom
({
matmul1_out_var
,
eltadd1_b_var
})
.
LinksTo
({
eltadd1_out_var
});
reshape2_1
->
LinksFrom
({
eltadd1_out_var
}).
LinksTo
({
reshape2_1_out_var
});
transpose2_1
->
LinksFrom
({
reshape2_1_out_var
}).
LinksTo
({
transpose2_1_out_var
});
// V path Nodes
auto
*
matmul2
=
pattern
->
NewNode
(
matmul2_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul2_w_var
=
pattern
->
NewNode
(
matmul2_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul2_out_var
=
pattern
->
NewNode
(
matmul2_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd2
=
pattern
->
NewNode
(
eltadd2_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd2_b_var
=
pattern
->
NewNode
(
eltadd2_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd2_out_var
=
pattern
->
NewNode
(
eltadd2_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_2
=
pattern
->
NewNode
(
reshape2_2_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_2_out_var
=
pattern
->
NewNode
(
reshape2_2_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_2
=
pattern
->
NewNode
(
transpose2_2_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_2_out_var
=
pattern
->
NewNode
(
transpose2_2_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
// V path Links
matmul2
->
LinksFrom
({
c_identity2_out_var
,
matmul2_w_var
})
.
LinksTo
({
matmul2_out_var
});
eltadd2
->
LinksFrom
({
matmul2_out_var
,
eltadd2_b_var
})
.
LinksTo
({
eltadd2_out_var
});
reshape2_2
->
LinksFrom
({
eltadd2_out_var
}).
LinksTo
({
reshape2_2_out_var
});
transpose2_2
->
LinksFrom
({
reshape2_2_out_var
}).
LinksTo
({
transpose2_2_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax"
);
auto
*
softmax_qk
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// QK path Linsk
matmul_qk
->
LinksFrom
({
scale_q_out_var
,
transpose2_1_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
// QKV path Nodes
// QKV path Nodes
auto
*
matmul_qkv
=
auto
*
matmul_qkv
=
...
@@ -897,7 +949,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -897,7 +949,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
->
AsIntermediate
();
->
AsIntermediate
();
// QKV path Links
// QKV path Links
matmul_qkv
->
LinksFrom
({
softmax_qk_out_var
,
split0_v
_out_var
})
matmul_qkv
->
LinksFrom
({
softmax_qk_out_var
,
transpose2_2
_out_var
})
.
LinksTo
({
matmul_qkv_out_var
});
.
LinksTo
({
matmul_qkv_out_var
});
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
.
LinksTo
({
transpose2_qkv_out_var
});
.
LinksTo
({
transpose2_qkv_out_var
});
...
@@ -912,38 +964,41 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -912,38 +964,41 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
eltadd_out
->
LinksFrom
({
input0
,
eltadd_linear_out_var
})
eltadd_out
->
LinksFrom
({
input0
,
eltadd_linear_out_var
})
.
LinksTo
({
attention_output
});
.
LinksTo
({
attention_output
});
// Feed Forward LayerNorm Nodes
// post-LayerNorm
auto
*
ffn_layer_norm
=
auto
*
layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsOutput
()
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_
layer_norm_variance_var
=
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_
layer_norm_variance_repr
())
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
As
Intermediate
()
->
As
Output
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_
layer_norm_out_repr
())
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
->
assert_is_op_input
(
"c_identity"
,
"X"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
2
)
{
return
true
;
}
else
{
return
false
;
}
});
ffn_layer_norm
layer_norm
->
LinksFrom
(
->
LinksFrom
({
attention_output
,
layer_norm_bias_var
,
layer_norm_scale_var
})
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
(
.
LinksTo
({
ffn_layer_norm_out_var
,
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// communication c_identity
// communication c_identity
auto
*
ffn_c_identity
=
auto
*
ffn_c_identity
=
...
@@ -952,7 +1007,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -952,7 +1007,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
ffn_c_identity
->
LinksFrom
({
ffn_
layer_norm_out_var
})
ffn_c_identity
->
LinksFrom
({
layer_norm_out_var
})
.
LinksTo
({
ffn_c_identity_out_var
});
.
LinksTo
({
ffn_c_identity_out_var
});
// Feed Forward fc1 -> gelu -> fc2
// Feed Forward fc1 -> gelu -> fc2
...
@@ -974,11 +1029,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -974,11 +1029,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op
_input
(
"gelu"
);
->
assert_is_op
s_input
(
FFN_ACTS
);
auto
*
ffn_
gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_
act
=
pattern
->
NewNode
(
ffn_act_repr
())
->
assert_is_ops
(
FFN_ACTS
);
auto
*
ffn_
gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu
_out_repr
())
auto
*
ffn_
act_out_var
=
pattern
->
NewNode
(
ffn_act
_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
assert_is_ops_output
(
FFN_ACTS
)
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
->
assert_is_op_input
(
"matmul_v2"
);
@@ -1015,297 +1070,1504 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
      pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
  auto* ffn_output = pattern->NewNode(ffn_output_repr())
                         ->assert_is_op_output("elementwise_add")
                         ->AsOutput();
                         ->AsIntermediate()
                         ->assert_is_op_input("layer_norm");

  ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var})
      .LinksTo({ffn_matmul0_out_var});
  ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
      .LinksTo({ffn_eltadd0_out_var});
  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
      .LinksTo({ffn_matmul1_out_var});
  ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
      .LinksTo({ffn_c_allreduce_sum_out_var});
  ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
      .LinksTo({ffn_eltadd1_out_var});
  ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
  ffn_eltadd_out->LinksFrom({layer_norm_out_var, ffn_eltadd1_out_var})
      .LinksTo({ffn_output});

  return ffn_output;
}

}  // namespace patterns

template <typename T>
inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor,
                              phi::DenseTensor* wk_tensor,
                              phi::DenseTensor* wv_tensor,
                              const int num_head,
                              const int dim_head,
                              const int dim_embed) {
  auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
  auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
  auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());

  auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});

  phi::DenseTensor tmp_combined_w_tensor;
  tmp_combined_w_tensor.Resize(combined_w_dims);
  auto* tmp_combined_w_data =
      tmp_combined_w_tensor.mutable_data<T>(platform::CPUPlace());

  std::vector<T*> w_vec = {wq_data, wk_data, wv_data};
  // Combine the three fc weights together.
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < num_head; j++) {
      for (int k = 0; k < dim_head; k++) {
        for (int l = 0; l < dim_embed; l++) {
          int out_idx = i * num_head * dim_head * dim_embed +
                        j * dim_head * dim_embed + k * dim_embed + l;
          int in_idx = l * num_head * dim_head + j * dim_head + k;
          tmp_combined_w_data[out_idx] = w_vec[i][in_idx];
        }
      }
    }
  }

  wq_tensor->Resize(combined_w_dims);
  auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
  memcpy(
      new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());
}

template <typename T>
inline void QKVBiasProcess(phi::DenseTensor* bq_tensor,
                           phi::DenseTensor* bk_tensor,
                           phi::DenseTensor* bv_tensor,
                           const int num_head,
                           const int dim_head,
                           const int dim_embed) {
  auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace());
  auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
  auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());

  auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head});

  phi::DenseTensor tmp_combined_bias_tensor;
  tmp_combined_bias_tensor.Resize(combined_bias_dims);
  auto* tmp_combined_bias_data =
      tmp_combined_bias_tensor.mutable_data<T>(platform::CPUPlace());

  size_t bias_size = bq_tensor->numel();
  memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
  memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
  memcpy(
      tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);

  bq_tensor->Resize(combined_bias_dims);
  auto* new_combined_bias_data =
      bq_tensor->mutable_data<T>(platform::CPUPlace());
  memcpy(new_combined_bias_data,
         tmp_combined_bias_data,
         sizeof(T) * bq_tensor->numel());
}
  // Feed Forward LayerNorm Nodes
  auto* ffn_layer_norm =
      pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
  auto* ffn_layer_norm_scale_var =
      pattern->NewNode(ffn_layer_norm_scale_repr())
          ->AsInput()
          ->assert_is_persistable_var()
          ->assert_is_op_input("layer_norm", "Scale");
  auto* ffn_layer_norm_bias_var =
      pattern->NewNode(ffn_layer_norm_bias_repr())
          ->AsInput()
          ->assert_is_persistable_var()
          ->assert_is_op_input("layer_norm", "Bias");
  auto* ffn_layer_norm_mean_var =
      pattern->NewNode(ffn_layer_norm_mean_repr())
          ->AsOutput()
          ->assert_is_op_output("layer_norm", "Mean");
  auto* ffn_layer_norm_variance_var =
      pattern->NewNode(ffn_layer_norm_variance_repr())
          ->AsOutput()
          ->assert_is_op_output("layer_norm", "Variance");
  auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
                                     ->AsOutput()
                                     ->assert_is_op_output("layer_norm", "Y");

  ffn_layer_norm
      ->LinksFrom(
          {ffn_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
      .LinksTo({ffn_layer_norm_out_var,
                ffn_layer_norm_mean_var,
                ffn_layer_norm_variance_var});

  return ffn_layer_norm_out_var;
}

PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
  auto* input0 = pattern->NewNode(input0_repr());
  input0->assert_is_op_input("layer_norm", "X");

  // pre-LayerNorm
  auto* layer_norm =
      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
  auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
                                   ->AsInput()
                                   ->assert_is_persistable_var()
                                   ->assert_is_op_input("layer_norm", "Scale");
  auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
                                  ->AsInput()
                                  ->assert_is_persistable_var()
                                  ->assert_is_op_input("layer_norm", "Bias");
  auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
                                  ->AsOutput()
                                  ->assert_is_op_output("layer_norm", "Mean");
  auto* layer_norm_variance_var =
      pattern->NewNode(layer_norm_variance_repr())
          ->AsOutput()
          ->assert_is_op_output("layer_norm", "Variance");
  auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
                                 ->AsIntermediate()
                                 ->assert_is_op_output("layer_norm", "Y")
                                 ->assert_is_op_input("c_identity", "X");

  layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
      .LinksTo(
          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});

  // communication c_identity
  auto* c_identity =
      pattern->NewNode(c_identity_repr())->assert_is_op("c_identity");
  auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr())
                                 ->AsIntermediate()
                                 ->assert_is_op_output("c_identity", "Out")
                                 ->assert_is_op_input("matmul_v2", "X");
  c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var});

  // QKV fused path Nodes
  auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
  auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
                            ->AsInput()
                            ->assert_is_op_input("matmul_v2", "Y");
  auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
                              ->assert_is_op_output("matmul_v2")
                              ->AsIntermediate()
                              ->assert_is_op_input("elementwise_add");

  auto* eltadd0 =
      pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
  auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
                            ->AsInput()
                            ->assert_is_op_input("elementwise_add", "Y");
  auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
                              ->assert_is_op_output("elementwise_add")
                              ->AsIntermediate()
                              ->assert_is_op_input("reshape2");
  auto* reshape2_0 =
      pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
  auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
                                 ->assert_is_op_output("reshape2")
                                 ->AsIntermediate()
                                 ->assert_is_op_input("transpose2");

  auto* transpose2_0 =
      pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
  auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
                                   ->assert_is_op_output("transpose2")
                                   ->AsIntermediate()
                                   ->assert_is_op_input("split", "X");

  auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
  auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
                               ->assert_is_op_output("split")
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2", "X");
  auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
                               ->assert_is_op_output("split")
                               ->AsOutput()
                               ->assert_is_op_input("matmul_v2", "Y")
                               ->assert_is_op_input("while");
  auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
                               ->assert_is_op_output("split")
                               ->AsOutput()
                               ->assert_is_op_input("matmul_v2", "Y")
                               ->assert_is_op_input("while");

  // QKV fused path Links
  matmul0->LinksFrom({c_identity_out_var, matmul0_w_var})
      .LinksTo({matmul0_out_var});
  eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
      .LinksTo({eltadd0_out_var});
  reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
  transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
  split0->LinksFrom({transpose2_0_out_var})
      .LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});

  // while loop
  auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
  while0->LinksFrom({split0_k_out_var, split0_v_out_var});

  // QK path Nodes
  auto* matmul_qk =
      pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
  auto* matmul_qk_out_var =
      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");

  auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
  auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
                               ->assert_is_op_output("scale")
                               ->AsIntermediate()
                               ->assert_is_op_input("elementwise_add", "X");

  auto* eltadd_qk =
      pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
  auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
                              ->AsInput()
                              ->assert_is_op_input("elementwise_add", "Y");
  auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
                                ->assert_is_op_output("elementwise_add")
                                ->AsIntermediate()
                                ->assert_is_op_input("softmax");

  auto* softmax_qk =
      pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
  auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
                                 ->assert_is_op_output("softmax")
                                 ->AsIntermediate()
                                 ->assert_is_op_input("matmul_v2", "X");

  // QK path Links
  matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
      .LinksTo({matmul_qk_out_var});
  scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
  eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
      .LinksTo({eltadd_qk_out_var});
  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});

  // QKV path Nodes
  auto* matmul_qkv =
      pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
  auto* matmul_qkv_out_var =
      pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
  matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");

  auto* transpose2_qkv =
      pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
  auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
                                     ->assert_is_op_output("transpose2");
  transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");

  auto* reshape2_qkv =
      pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
  auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
                                   ->assert_is_op_output("reshape2")
                                   ->AsIntermediate()
                                   ->assert_is_op_input("matmul_v2");

  // -> out_linear
  auto* matmul_linear =
      pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
  auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
                                  ->AsInput()
                                  ->assert_is_op_input("matmul_v2", "Y");
  auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
                                    ->assert_is_op_output("matmul_v2")
                                    ->AsIntermediate()
                                    ->assert_is_op_input("c_allreduce_sum");

  // communication c_allreduce_sum
  auto* c_allreduce_sum =
      pattern->NewNode(c_allreduce_sum_repr())->assert_is_op("c_allreduce_sum");
  auto* c_allreduce_sum_out_var = pattern->NewNode(c_allreduce_sum_out_repr())
                                      ->assert_is_op_output("c_allreduce_sum")
                                      ->AsIntermediate()
                                      ->assert_is_op_input("elementwise_add");

  auto* eltadd_linear =
      pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
  auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
                                  ->AsInput()
                                  ->assert_is_op_input("elementwise_add", "Y");
  auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
                                    ->assert_is_op_output("elementwise_add")
                                    ->AsIntermediate()
                                    ->assert_is_op_input("elementwise_add");

  auto* eltadd_out =
      pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
  auto* attention_output = pattern->NewNode(attention_output_repr())
                               ->assert_is_op_output("elementwise_add")
                               ->AsIntermediate();

  // QKV path Links
  matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var})
      .LinksTo({matmul_qkv_out_var});
  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
      .LinksTo({transpose2_qkv_out_var});
  reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
      .LinksTo({reshape2_qkv_out_var});
  matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
      .LinksTo({matmul_linear_out_var});
  c_allreduce_sum->LinksFrom({matmul_linear_out_var})
      .LinksTo({c_allreduce_sum_out_var});
  eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
      .LinksTo({eltadd_linear_out_var});
  eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
      .LinksTo({attention_output});

  // Feed Forward LayerNorm Nodes
  auto* ffn_layer_norm =
      pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
  auto* ffn_layer_norm_scale_var =
      pattern->NewNode(ffn_layer_norm_scale_repr())
          ->AsInput()
          ->assert_is_persistable_var()
          ->assert_is_op_input("layer_norm", "Scale");
  auto* ffn_layer_norm_bias_var =
      pattern->NewNode(ffn_layer_norm_bias_repr())
          ->AsInput()
          ->assert_is_persistable_var()
          ->assert_is_op_input("layer_norm", "Bias");
  auto* ffn_layer_norm_mean_var =
      pattern->NewNode(ffn_layer_norm_mean_repr())
          ->AsIntermediate()
          ->assert_is_op_output("layer_norm", "Mean");
  auto* ffn_layer_norm_variance_var =
      pattern->NewNode(ffn_layer_norm_variance_repr())
          ->AsIntermediate()
          ->assert_is_op_output("layer_norm", "Variance");
  auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
                                     ->AsIntermediate()
                                     ->assert_is_op_output("layer_norm", "Y")
                                     ->assert_is_op_input("c_identity", "X");

  ffn_layer_norm
      ->LinksFrom(
          {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
      .LinksTo({ffn_layer_norm_out_var,
                ffn_layer_norm_mean_var,
                ffn_layer_norm_variance_var});

  // communication c_identity
  auto* ffn_c_identity =
      pattern->NewNode(ffn_c_identity_repr())->assert_is_op("c_identity");
  auto* ffn_c_identity_out_var = pattern->NewNode(ffn_c_identity_out_repr())
                                     ->assert_is_op_output("c_identity", "Out")
                                     ->AsIntermediate()
                                     ->assert_is_op_input("matmul_v2", "X");
  ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
      .LinksTo({ffn_c_identity_out_var});

  // Feed Forward fc1 -> act -> fc2
  auto* ffn_matmul0 =
      pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
  auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
                                ->AsInput()
                                ->assert_is_op_input("matmul_v2", "Y");
  auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
                                  ->assert_is_op_output("matmul_v2")
                                  ->AsIntermediate()
                                  ->assert_is_op_input("elementwise_add");

  auto* ffn_eltadd0 =
      pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
  auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
                                ->AsInput()
                                ->assert_is_op_input("elementwise_add", "Y");
  auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                  ->assert_is_op_output("elementwise_add")
                                  ->AsIntermediate()
                                  ->assert_is_ops_input(FFN_ACTS);

  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
                              ->assert_is_ops_output(FFN_ACTS)
                              ->AsIntermediate()
                              ->assert_is_op_input("matmul_v2");

  auto* ffn_matmul1 =
      pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
  auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
                                ->AsInput()
                                ->assert_is_op_input("matmul_v2", "Y");
  auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
                                  ->assert_is_op_output("matmul_v2")
                                  ->AsIntermediate()
                                  ->assert_is_op_input("c_allreduce_sum");

  // communication c_allreduce_sum
  auto* ffn_c_allreduce_sum = pattern->NewNode(ffn_c_allreduce_sum_repr())
                                  ->assert_is_op("c_allreduce_sum");
  auto* ffn_c_allreduce_sum_out_var =
      pattern->NewNode(ffn_c_allreduce_sum_out_repr())
          ->assert_is_op_output("c_allreduce_sum")
          ->AsIntermediate()
          ->assert_is_op_input("elementwise_add");

  auto* ffn_eltadd1 =
      pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
  auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
                                ->AsInput()
                                ->assert_is_op_input("elementwise_add", "Y");
  auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
                                  ->assert_is_op_output("elementwise_add")
                                  ->AsIntermediate()
                                  ->assert_is_op_input("elementwise_add");

  auto* ffn_eltadd_out =
      pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
  auto* ffn_output = pattern->NewNode(ffn_output_repr())
                         ->assert_is_op_output("elementwise_add")
                         ->AsOutput();

  ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var})
      .LinksTo({ffn_matmul0_out_var});
  ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
      .LinksTo({ffn_eltadd0_out_var});
  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
      .LinksTo({ffn_matmul1_out_var});
  ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
      .LinksTo({ffn_c_allreduce_sum_out_var});
  ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
      .LinksTo({ffn_eltadd1_out_var});
  ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
      .LinksTo({ffn_output});

  return ffn_output;
}

}  // namespace patterns
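// Editor's note (illustrative sketch, not part of the diff): the op chain
// the FuseQKV multi-devices pattern above is built to match, reading the
// assert_is_op_input/output constraints top to bottom:
//
//   X -> layer_norm -> c_identity -> matmul_v2(QKVW) -> elementwise_add(QKVBias)
//     -> reshape2 -> transpose2 -> split -> {Q, K, V}   (K and V also feed while)
//   Q x K^T -> scale -> elementwise_add(SrcMask) -> softmax -> x V
//     -> transpose2 -> reshape2 -> matmul_v2(OutLinearW) -> c_allreduce_sum
//     -> elementwise_add(bias) -> elementwise_add(residual with X)
//   ... followed by the same c_identity / c_allreduce_sum wrapped FFN block.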
template <typename T>
inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor,
                              phi::DenseTensor* wk_tensor,
                              phi::DenseTensor* wv_tensor,
                              const int num_head,
                              const int dim_head,
                              const int dim_embed) {
  auto* dev_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  auto* wq_data = wq_tensor->data<T>();
  auto* wk_data = wk_tensor->data<T>();
  auto* wv_data = wv_tensor->data<T>();

  auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});

  phi::DenseTensor tmp_combined_w_tensor;
  tmp_combined_w_tensor.Resize(combined_w_dims);
  dev_ctx->Alloc<T>(&tmp_combined_w_tensor);
  auto* tmp_combined_w_data = tmp_combined_w_tensor.data<T>();

  std::vector<T*> w_vec = {wq_data, wk_data, wv_data};
  // Combine the three fc weights together.
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < num_head; j++) {
      for (int k = 0; k < dim_head; k++) {
        for (int l = 0; l < dim_embed; l++) {
          int out_idx = i * num_head * dim_head * dim_embed +
                        j * dim_head * dim_embed + k * dim_embed + l;
          int in_idx = l * num_head * dim_head + j * dim_head + k;
          tmp_combined_w_data[out_idx] = w_vec[i][in_idx];
        }
      }
    }
  }

  wq_tensor->Resize(combined_w_dims);
  dev_ctx->Alloc<T>(wq_tensor);
  auto* new_combined_w_data = wq_tensor->data<T>();
  memcpy(
      new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());
}

template <typename T>
inline void QKVBiasProcess(phi::DenseTensor* bq_tensor,
                           phi::DenseTensor* bk_tensor,
                           phi::DenseTensor* bv_tensor,
                           const int num_head,
                           const int dim_head,
                           const int dim_embed) {
  auto* dev_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  auto* bq_data = bq_tensor->data<T>();
  auto* bk_data = bk_tensor->data<T>();
  auto* bv_data = bv_tensor->data<T>();

  auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head});

  phi::DenseTensor tmp_combined_bias_tensor;
  tmp_combined_bias_tensor.Resize(combined_bias_dims);
  dev_ctx->Alloc<T>(&tmp_combined_bias_tensor);
  auto* tmp_combined_bias_data = tmp_combined_bias_tensor.data<T>();

  size_t bias_size = bq_tensor->numel();
  memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
  memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
  memcpy(
      tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);

  bq_tensor->Resize(combined_bias_dims);
  dev_ctx->Alloc<T>(bq_tensor);
  auto* new_combined_bias_data = bq_tensor->data<T>();
  memcpy(new_combined_bias_data,
         tmp_combined_bias_data,
         sizeof(T) * bq_tensor->numel());
}
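// Editor's note: a minimal standalone sketch (editor's addition, assuming the
// row-major [dim_embed, num_head * dim_head] fc weight layout the loops above
// imply) of the gather performed by QKVWeightsProcess, convenient for
// sanity-checking the layout change on toy sizes:
//
//   // out is [3, num_head, dim_head, dim_embed]; w[i] is fc weight i (q/k/v)
//   out[((i * num_head + j) * dim_head + k) * dim_embed + l] =
//       w[i][(l * num_head + j) * dim_head + k];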
inline void QKVWeightsBiasProcess(phi::DenseTensor* wq_tensor,
                                  phi::DenseTensor* wk_tensor,
                                  phi::DenseTensor* wv_tensor,
                                  phi::DenseTensor* bq_tensor,
                                  phi::DenseTensor* bk_tensor,
                                  phi::DenseTensor* bv_tensor,
                                  const int num_head,
                                  const int dim_head,
                                  const int dim_embed) {
  switch (wq_tensor->dtype()) {
    case paddle::experimental::DataType::FLOAT16:
      QKVWeightsProcess<platform::float16>(
          wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
      break;
    case paddle::experimental::DataType::FLOAT32:
      QKVWeightsProcess<float>(
          wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
      break;
    case paddle::experimental::DataType::INT8:
      QKVWeightsProcess<int8_t>(
          wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "fused_multi_transformer not supported weight dtype. "
          "we now only support fp32/fp16/int8."));
      break;
  }
  switch (bq_tensor->dtype()) {
    case paddle::experimental::DataType::FLOAT16:
      QKVBiasProcess<platform::float16>(
          bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
      break;
    case paddle::experimental::DataType::FLOAT32:
      QKVBiasProcess<float>(
          bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "fused_multi_transformer not supported bias dtype. "
          "we now only support fp32/fp16."));
      break;
  }
}
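// Editor's note (interpretive reading, not stated in the diff): unlike the
// weight switch, the bias switch above has no INT8 case; in the int8 path
// only the weights are quantized while biases stay fp32/fp16, which matches
// the narrower dtype list in the bias error message.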
template <typename T>
inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
                                     const int num_head,
                                     const int dim_head,
                                     const int dim_embed) {
  auto* dev_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  auto* qkv_w_data = qkv_w_tensor->data<T>();
  auto transpose_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});

  phi::DenseTensor tmp_transpose_w_tensor;
  tmp_transpose_w_tensor.Resize(transpose_w_dims);
  dev_ctx->Alloc<T>(&tmp_transpose_w_tensor);
  auto* tmp_transpose_w_data = tmp_transpose_w_tensor.data<T>();

  // transpose qkv matmul Y to QKVWeights
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < num_head; j++) {
      for (int k = 0; k < dim_head; k++) {
        for (int l = 0; l < dim_embed; l++) {
          int out_idx = i * num_head * dim_head * dim_embed +
                        j * dim_head * dim_embed + k * dim_embed + l;
          int in_idx =
              l * num_head * 3 * dim_head + j * 3 * dim_head + i * dim_head + k;
          tmp_transpose_w_data[out_idx] = qkv_w_data[in_idx];
        }
      }
    }
  }

  qkv_w_tensor->Resize(transpose_w_dims);
  dev_ctx->Alloc<T>(qkv_w_tensor);
  auto* new_transpose_w_data = qkv_w_tensor->data<T>();
  memcpy(new_transpose_w_data,
         tmp_transpose_w_data,
         sizeof(T) * qkv_w_tensor->numel());
}
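// Editor's note: a worked toy check of the in_idx/out_idx mapping above
// (editor's addition, illustrative only). With num_head = 2, dim_head = 1,
// dim_embed = 2, the fused matmul weight is laid out
// [dim_embed, num_head * 3 * dim_head]; the element (i = 1 /*k*/, j = 0,
// k = 0, l = 1) reads in_idx = 1*2*3*1 + 0 + 1*1 + 0 = 7 and writes
// out_idx = 1*2*1*2 + 0 + 0 + 1 = 5, i.e. the K weight of head 0 for input
// feature 1 lands in the K block of the [3, num_head, dim_head, dim_embed]
// tensor.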
template <typename T>
inline void QKVBiasProcessFuseQKV(phi::DenseTensor* qkv_b_tensor,
                                  const int num_head,
                                  const int dim_head,
                                  const int dim_embed) {
  auto* dev_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  auto* qkv_b_data = qkv_b_tensor->data<T>();
  auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head});

  phi::DenseTensor tmp_transpose_b_tensor;
  tmp_transpose_b_tensor.Resize(transpose_b_dims);
  dev_ctx->Alloc<T>(&tmp_transpose_b_tensor);
  auto* tmp_transpose_b_data = tmp_transpose_b_tensor.data<T>();

  // transpose qkv elementwise_add Y to QKVBias
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < num_head; j++) {
      for (int k = 0; k < dim_head; k++) {
        int out_idx = i * num_head * dim_head + j * dim_head + k;
        int in_idx = j * 3 * dim_head + i * dim_head + k;
        tmp_transpose_b_data[out_idx] = qkv_b_data[in_idx];
      }
    }
  }

  qkv_b_tensor->Resize({3, num_head, dim_head});
  dev_ctx->Alloc<T>(qkv_b_tensor);
  auto* new_transpose_b_data = qkv_b_tensor->data<T>();
  memcpy(new_transpose_b_data,
         tmp_transpose_b_data,
         sizeof(T) * qkv_b_tensor->numel());
}
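// Editor's note: a worked toy check of QKVBiasProcessFuseQKV's mapping
// (editor's addition). With num_head = 2 and dim_head = 2, the fused bias is
// laid out per head as [q0 q1 k0 k1 v0 v1 | ...]; the element
// (i = 1 /*k*/, j = 0, k = 1) reads in_idx = 0*6 + 1*2 + 1 = 3 and writes
// out_idx = 1*4 + 0 + 1 = 5 in the [3, num_head, dim_head] output.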
inline void QKVWeightsBiasProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
                                         phi::DenseTensor* qkv_b_tensor,
                                         const int num_head,
                                         const int dim_head,
                                         const int dim_embed) {
  switch (qkv_w_tensor->dtype()) {
    case paddle::experimental::DataType::FLOAT16:
      QKVWeightsProcessFuseQKV<platform::float16>(
          qkv_w_tensor, num_head, dim_head, dim_embed);
      break;
    case paddle::experimental::DataType::FLOAT32:
      QKVWeightsProcessFuseQKV<float>(
          qkv_w_tensor, num_head, dim_head, dim_embed);
      break;
    case paddle::experimental::DataType::INT8:
      QKVWeightsProcessFuseQKV<int8_t>(
          qkv_w_tensor, num_head, dim_head, dim_embed);
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "fused_multi_transformer not supported weight dtype. "
          "we now only support fp32/fp16/int8."));
      break;
  }
  switch (qkv_b_tensor->dtype()) {
    case paddle::experimental::DataType::FLOAT16:
      QKVBiasProcessFuseQKV<platform::float16>(
          qkv_b_tensor, num_head, dim_head, dim_embed);
      break;
    case paddle::experimental::DataType::FLOAT32:
      QKVBiasProcessFuseQKV<float>(qkv_b_tensor, num_head, dim_head, dim_embed);
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "fused_multi_transformer not supported bias dtype. "
          "we now only support fp32/fp16."));
      break;
  }
}
// Only used for fused_multi_transformer_int8
inline void TransposeWeights(phi::DenseTensor* weight_tensor) {
  auto* dev_ctx = static_cast<phi::CPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  int m = weight_tensor->dims()[0];
  int n = weight_tensor->dims()[1];
  phi::DenseTensor tmp_weight_tensor;
  tmp_weight_tensor.Resize({n, m});
  dev_ctx->Alloc<int8_t>(&tmp_weight_tensor);
  auto tmp_weight_data = tmp_weight_tensor.data<int8_t>();
  auto weight_data = weight_tensor->data<int8_t>();
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int in_idx = i * n + j;
      int out_idx = j * m + i;
      tmp_weight_data[out_idx] = weight_data[in_idx];
    }
  }
  weight_tensor->Resize({n, m});
  dev_ctx->Alloc<int8_t>(weight_tensor);
  auto new_weight_data = weight_tensor->data<int8_t>();
  memcpy(new_weight_data, tmp_weight_data, sizeof(int8_t) * m * n);
}
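// Editor's note (illustrative): TransposeWeights converts a row-major [m, n]
// int8 weight into its [n, m] transpose, e.g. for m = 2, n = 3 the element
// at (i = 0, j = 2) moves from in_idx = 0*3 + 2 = 2 to out_idx = 2*2 + 0 = 4;
// presumably this is the layout the int8 fused kernel expects.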
inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) {
  auto var_desc = VarDesc(name);
  var_desc.SetDataType(framework::proto::VarType::FP32);
  var_desc.SetPersistable(true);
  auto node = graph->CreateVarNode(&var_desc);
  return node;
}
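// Editor's note: in this pass, CreatePersistableVarNode backs the
// "<weight_name>_out_scale" graph inputs that the int8 branch of
// BuildFusion below creates and links into the fused op.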
int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
                                                  const std::string& name_scope,
                                                  Scope* scope) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  bool enable_int8 = graph->Get<bool>("enable_int8");
  if (enable_int8) {
    VLOG(3) << "FusedMultiTransformerEncoderPass with int8";
  } else {
    VLOG(3) << "FusedMultiTransformerEncoderPass with fp";
  }

  // Create pattern.
  patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern(
      pattern, name_scope);
  fused_multi_transformer_pattern();

  // Create New OpDesc
  auto fuse_creater = [&](Node* input0,
                          Node* layer_norm,
                          Node* layer_norm_scale,
                          Node* layer_norm_bias,
                          Node* layer_norm_mean,
                          Node* layer_norm_variance,
                          Node* matmul0,
                          Node* matmul0_w,
                          Node* matmul1_w,
                          Node* matmul2_w,
                          Node* eltadd0_b,
                          Node* eltadd1_b,
                          Node* eltadd2_b,
                          Node* transpose2_1_out,
                          Node* transpose2_2_out,
                          Node* eltadd_qk_b,
                          Node* reshape2_0,
                          Node* matmul_linear,
                          Node* matmul_linear_w,
                          Node* eltadd_linear_b,
                          Node* ffn_layer_norm,
                          Node* ffn_layer_norm_scale,
                          Node* ffn_layer_norm_bias,
                          Node* ffn_layer_norm_mean,
                          Node* ffn_layer_norm_variance,
                          Node* ffn_matmul0,
                          Node* ffn_matmul0_w,
                          Node* ffn_matmul1,
                          Node* ffn_matmul1_w,
                          Node* ffn_eltadd0_b,
                          Node* ffn_eltadd1_b,
                          Node* ffn_act,
                          Node* ffn_layer_norm_out) {
    auto* matmul0_op = matmul0->Op();
    auto* matmul_linear_op = matmul_linear->Op();
    auto* ffn_matmul_0_op = ffn_matmul0->Op();
    auto* ffn_matmul_1_op = ffn_matmul1->Op();

    // Calc index of transformer layer by LayerNorm Scale name
    // This calculation assumes:
    //    1. no LayerNorm before all transformer layer
    //    2. each transformer layer contains 2 LayerNorm layer
    auto ln_scale_name = layer_norm_scale->Name();
    auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
    auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
    int layer_idx = atoi(ln_idx_str.c_str()) / 2;

    auto* wq_tensor =
        scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>();
    auto* wk_tensor =
        scope->FindVar(matmul1_w->Name())->GetMutable<phi::DenseTensor>();
    auto* wv_tensor =
        scope->FindVar(matmul2_w->Name())->GetMutable<phi::DenseTensor>();

    auto* bq_tensor =
        scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
    auto* bk_tensor =
        scope->FindVar(eltadd1_b->Name())->GetMutable<phi::DenseTensor>();
    auto* bv_tensor =
        scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();

    // NOTE(minghaoBD): to make it compatible with strucutured pruning on
    // num_head dimension:
    // 1. get dim_head from reshape.shape[3], dim_embed from
    // layer_norm_bias.shape[0]
    // 2. calculate num_head according to wq_tensor.shape[1] and dim_head
    auto reshape_desc = reshape2_0->Op();
    int dim_head =
        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
            .at(3);
    auto* layer_norm_bias_tensor =
        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
    int dim_embed = layer_norm_bias_tensor->dims()[0];
    int num_head = wq_tensor->dims()[1] / dim_head;

    QKVWeightsBiasProcess(wq_tensor,
                          wk_tensor,
                          wv_tensor,
                          bq_tensor,
                          bk_tensor,
                          bv_tensor,
                          num_head,
                          dim_head,
                          dim_embed);

    if (enable_int8) {
      auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name())
                                      ->GetMutable<phi::DenseTensor>();
      auto* ffn0_w_tensor =
          scope->FindVar(ffn_matmul0_w->Name())->GetMutable<phi::DenseTensor>();
      auto* ffn1_w_tensor =
          scope->FindVar(ffn_matmul1_w->Name())->GetMutable<phi::DenseTensor>();

      TransposeWeights(out_linear_w_tensor);
      TransposeWeights(ffn0_w_tensor);
      TransposeWeights(ffn1_w_tensor);
    }

    // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
    auto* combined_w_desc = matmul0_w->Var();
    combined_w_desc->SetShape({3, num_head, dim_head, dim_embed});
    combined_w_desc->SetPersistable(true);

    auto* combined_bias_desc = eltadd0_b->Var();
    combined_bias_desc->SetShape({3, num_head, dim_head});
    combined_bias_desc->SetPersistable(true);

    scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()});
    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

    // create fused_multi_transformer
    OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
    fused_multi_transformer_op_desc.SetType(enable_int8
                                                ? "fused_multi_transformer_int8"
                                                : "fused_multi_transformer");

    // 1. Input setting
    fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});

    // pre-LayerNorm input
    fused_multi_transformer_op_desc.SetInput("LnScale",
                                             {layer_norm_scale->Name()});
    fused_multi_transformer_op_desc.SetInput("LnBias",
                                             {layer_norm_bias->Name()});

    // QKV computation input
    fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
    fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
    fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});

    // CacheKV input
    VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
    // FIXME: only support max_seq_len <= 1024
    cache_kv_desc.SetDataType(
        framework::TransToProtoVarType(bq_tensor->dtype()));
    cache_kv_desc.SetPersistable(false);
    auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);

    OpDesc fill_const_op_desc(layer_norm->Op()->Block());
    fill_const_op_desc.SetType("fill_constant_batch_size_like");
    fill_const_op_desc.SetInput("Input", {input0->Name()});
    fill_const_op_desc.SetOutput("Out", {cache_kv->Name()});
    std::vector<int> shape = {2, -1, num_head, 1024, dim_head};
    fill_const_op_desc.SetAttr("shape", shape);
    fill_const_op_desc.SetAttr("input_dim_idx", 0);
    fill_const_op_desc.SetAttr("output_dim_idx", 1);
    fill_const_op_desc.SetAttr("value", 0);
    fill_const_op_desc.SetAttr("dtype",
                               static_cast<int>(framework::TransToProtoVarType(
                                   bq_tensor->dtype())));
    auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);

    fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});

    // Out Linear input
    fused_multi_transformer_op_desc.SetInput("OutLinearW",
                                             {matmul_linear_w->Name()});
    fused_multi_transformer_op_desc.SetInput("OutLinearBias",
                                             {eltadd_linear_b->Name()});

    // Feed Forward input
    fused_multi_transformer_op_desc.SetInput("FFNLnScale",
                                             {ffn_layer_norm_scale->Name()});
    fused_multi_transformer_op_desc.SetInput("FFNLnBias",
                                             {ffn_layer_norm_bias->Name()});
    fused_multi_transformer_op_desc.SetInput("FFN1Weight",
                                             {ffn_matmul0_w->Name()});
    fused_multi_transformer_op_desc.SetInput("FFN1Bias",
                                             {ffn_eltadd0_b->Name()});
    fused_multi_transformer_op_desc.SetInput("FFN2Weight",
                                             {ffn_matmul1_w->Name()});
    fused_multi_transformer_op_desc.SetInput("FFN2Bias",
                                             {ffn_eltadd1_b->Name()});

    // 2. Output setting
    fused_multi_transformer_op_desc.SetOutput("Out",
                                              {ffn_layer_norm_out->Name()});
    fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()});

    // Attribute setting
    fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", false);
    fused_multi_transformer_op_desc.SetAttr(
        "epsilon", layer_norm->Op()->GetAttr("epsilon"));

    fused_multi_transformer_op_desc.SetAttr("is_test", true);
    fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
    fused_multi_transformer_op_desc.SetAttr("act_method",
                                            {ffn_act->Op()->Type()});

    // Quantization attribute/Input
    if (enable_int8) {
      auto* dev_ctx = static_cast<phi::CPUContext*>(
          platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
      // Set input scale
      std::string qkv_input_name = matmul0_op->Input("X")[0];
      auto qkv_in_scale = PADDLE_GET_CONST(
          float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name));
      std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
      auto out_linear_in_scale = PADDLE_GET_CONST(
          float,
          matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
      std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0];
      auto ffn0_in_scale = PADDLE_GET_CONST(
          float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name));
      std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
      auto ffn1_in_scale = PADDLE_GET_CONST(
          float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));

      // Calc outscale and Set them
      auto qkv_weight_scale =
          PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale"));
      auto out_weight_scale =
          PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale"));
      auto ffn0_weight_scale =
          PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale"));
      auto ffn1_weight_scale =
          PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale"));

      auto qkv_out_scales = std::vector<float>(
          3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f));
      auto out_out_scales = std::vector<float>(
          dim_embed,
          (out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f));
      auto ffn0_out_scales = std::vector<float>(
          4 * dim_embed,
          (ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f));
      auto ffn1_out_scales = std::vector<float>(
          dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f));

      // Inverse input scale
      qkv_in_scale = 1.0f / qkv_in_scale;
      out_linear_in_scale = 1.0f / out_linear_in_scale;
      ffn0_in_scale = 1.0f / ffn0_in_scale;
      ffn1_in_scale = 1.0f / ffn1_in_scale;

      fused_multi_transformer_op_desc.SetAttr(
          "qkv_in_scale", std::vector<float>{qkv_in_scale});
      fused_multi_transformer_op_desc.SetAttr(
          "out_linear_in_scale", std::vector<float>{out_linear_in_scale});
      fused_multi_transformer_op_desc.SetAttr(
          "ffn1_in_scale", std::vector<float>{ffn0_in_scale});
      fused_multi_transformer_op_desc.SetAttr(
          "ffn2_in_scale", std::vector<float>{ffn1_in_scale});

      auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale");
      auto out_out_scale_var =
          scope->Var(matmul_linear_w->Name() + "_out_scale");
      auto ffn0_out_scale_var =
          scope->Var(ffn_matmul0_w->Name() + "_out_scale");
      auto ffn1_out_scale_var =
          scope->Var(ffn_matmul1_w->Name() + "_out_scale");

      auto* qkv_out_scale_tensor =
          qkv_out_scale_var->GetMutable<phi::DenseTensor>();
      qkv_out_scale_tensor->Resize({3 * dim_embed});
      dev_ctx->Alloc<float>(qkv_out_scale_tensor);
      auto qkv_out_scale_data = qkv_out_scale_tensor->data<float>();
      memcpy(qkv_out_scale_data,
             qkv_out_scales.data(),
             qkv_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "QKVOutScale", {matmul0_w->Name() + "_out_scale"});

      auto* out_out_scale_tensor =
          out_out_scale_var->GetMutable<phi::DenseTensor>();
      out_out_scale_tensor->Resize({dim_embed});
      dev_ctx->Alloc<float>(out_out_scale_tensor);
      auto out_out_scale_data = out_out_scale_tensor->data<float>();
      memcpy(out_out_scale_data,
             out_out_scales.data(),
             out_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});

      auto* ffn0_out_scale_tensor =
          ffn0_out_scale_var->GetMutable<phi::DenseTensor>();
      ffn0_out_scale_tensor->Resize({4 * dim_embed});
      dev_ctx->Alloc<float>(ffn0_out_scale_tensor);
      auto ffn0_out_scale_data = ffn0_out_scale_tensor->data<float>();
      memcpy(ffn0_out_scale_data,
             ffn0_out_scales.data(),
             ffn0_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});

      auto* ffn1_out_scale_tensor =
          ffn1_out_scale_var->GetMutable<phi::DenseTensor>();
      ffn1_out_scale_tensor->Resize({dim_embed});
      dev_ctx->Alloc<float>(ffn1_out_scale_tensor);
      auto ffn1_out_scale_data = ffn1_out_scale_tensor->data<float>();
      memcpy(ffn1_out_scale_data,
             ffn1_out_scales.data(),
             ffn1_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
    }

    auto* fused_multi_transformer =
        graph->CreateOpNode(&fused_multi_transformer_op_desc);

    if (enable_int8) {
      auto qkv_out_scale_node =
          CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
      auto out_out_scale_node = CreatePersistableVarNode(
          graph, matmul_linear_w->Name() + "_out_scale");
      auto ffn0_out_scale_node =
          CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale");
      auto ffn1_out_scale_node =
          CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale");

      IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
      IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
      IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
      IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
    }

    IR_NODE_LINK_TO(input0, fused_multi_transformer);
    IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
    IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);

    IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
    IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
    IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);

    IR_NODE_LINK_TO(input0, fill_const_op);
    IR_NODE_LINK_TO(fill_const_op, cache_kv);
    IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);

    IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
    IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);

    IR_NODE_LINK_TO(fused_multi_transformer, ffn_layer_norm_out);
  };

  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "fused_multi_transformer_encoder pass in "
                      "op compat failed.";
      return;
    }
    VLOG(4) << "handle MultiTransformer encoder fuse";

    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm, layer_norm, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
                              layer_norm_variance,
                              fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_out, layer_norm_out, fused_multi_transformer_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(
        matmul0, matmul0, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul0_out, matmul0_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul0_w, matmul0_w, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_0, reshape2_0, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_0, transpose2_0, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(
        matmul1, matmul1, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul1_out, matmul1_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul1_w, matmul1_w, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_1, reshape2_1, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_1, transpose2_1, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        scale_q, scale_q, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        scale_q_out, scale_q_out, fused_multi_transformer_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(
        matmul2, matmul2, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul2_out, matmul2_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul2_w, matmul2_w, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_2, reshape2_2, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_2, transpose2_2, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(
        attention_output, attention_output, fused_multi_transformer_pattern)

    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
                              ffn_layer_norm_scale,
                              fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
                              ffn_layer_norm_bias,
                              fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
                              ffn_layer_norm_mean,
                              fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
                              ffn_layer_norm_variance,
                              fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
                              ffn_layer_norm_out,
                              fused_multi_transformer_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_act, ffn_act, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(
        ffn_output, ffn_output, fused_multi_transformer_pattern)

    // nodes need be removed
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd0, eltadd0, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd0_b, eltadd0_b, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd0_out, eltadd0_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd1, eltadd1, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd1_b, eltadd1_b, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd1_out, eltadd1_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd2, eltadd2, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd2_b, eltadd2_b, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd2_out, eltadd2_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul_qk, matmul_qk, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd_qk, eltadd_qk, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        softmax_qk, softmax_qk, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
                              transpose2_qkv_out,
                              fused_multi_transformer_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(
        matmul_linear, matmul_linear, fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd_linear, eltadd_linear, fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(
        eltadd_out, eltadd_out, fused_multi_transformer_pattern)

    fuse_creater(input0,
                 layer_norm,
                 layer_norm_scale,
                 layer_norm_bias,
                 layer_norm_mean,
                 layer_norm_variance,
                 matmul0,
                 matmul0_w,
                 matmul1_w,
                 matmul2_w,
                 eltadd0_b,
                 eltadd1_b,
                 eltadd2_b,
                 transpose2_1_out,
                 transpose2_2_out,
                 eltadd_qk_b,
                 reshape2_0,
                 matmul_linear,
                 matmul_linear_w,
                 eltadd_linear_b,
                 ffn_layer_norm,
                 ffn_layer_norm_scale,
                 ffn_layer_norm_bias,
                 ffn_layer_norm_mean,
                 ffn_layer_norm_variance,
                 ffn_matmul0,
                 ffn_matmul0_w,
                 ffn_matmul1,
                 ffn_matmul1_w,
                 ffn_eltadd0_b,
                 ffn_eltadd1_b,
                 ffn_act,
                 ffn_layer_norm_out);

    std::unordered_set<const Node*> marked_nodes(
        {layer_norm,        layer_norm_mean,    layer_norm_variance,
         layer_norm_out,    matmul0,            matmul1,
         matmul2,           matmul0_out,        matmul1_out,
         matmul2_out,       eltadd0,            eltadd1,
         eltadd2,           eltadd0_out,        eltadd1_out,
         eltadd2_out,       reshape2_0,         reshape2_1,
         reshape2_2,        reshape2_0_out,     reshape2_1_out,
         reshape2_2_out,    transpose2_0,       transpose2_1,
         transpose2_2,      transpose2_0_out,   transpose2_1_out,
         transpose2_2_out,  scale_q,            scale_q_out,
         matmul_qk,         matmul_qk_out,      eltadd_qk,
         eltadd_qk_out,     softmax_qk,         softmax_qk_out,
         transpose2_qkv,    transpose2_qkv_out, matmul_qkv,
         matmul_qkv_out,    reshape2_qkv,       transpose2_qkv,
         transpose2_qkv_out,                    matmul_linear,
         matmul_linear_out, eltadd_linear,      eltadd_linear_out,
         eltadd_out,        ffn_layer_norm,     ffn_layer_norm_mean,
         ffn_layer_norm_variance,               ffn_matmul0,
         ffn_matmul1,       ffn_matmul0_out,    ffn_matmul1_out,
         ffn_eltadd0,       ffn_eltadd1,        ffn_eltadd0_out,
         ffn_eltadd1_out,   ffn_act,            ffn_act_out,
         ffn_output,        ffn_eltadd_out});
inline
void
QKVWeightsBiasProcess
(
phi
::
DenseTensor
*
wq_tensor
,
// Remove unneeded nodes.
phi
::
DenseTensor
*
wk_tensor
,
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
phi
::
DenseTensor
*
wv_tensor
,
++
fusion_count
;
phi
::
DenseTensor
*
bq_tensor
,
};
phi
::
DenseTensor
*
bk_tensor
,
gpd
(
graph
,
handler
);
phi
::
DenseTensor
*
bv_tensor
,
const
int
num_head
,
const
int
dim_head
,
const
int
dim_embed
)
{
switch
(
wq_tensor
->
dtype
())
{
case
paddle
::
experimental
::
DataType
::
FLOAT16
:
QKVWeightsProcess
<
platform
::
float16
>
(
wq_tensor
,
wk_tensor
,
wv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
case
paddle
::
experimental
::
DataType
::
FLOAT32
:
QKVWeightsProcess
<
float
>
(
wq_tensor
,
wk_tensor
,
wv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
case
paddle
::
experimental
::
DataType
::
INT8
:
QKVWeightsProcess
<
int8_t
>
(
wq_tensor
,
wk_tensor
,
wv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32/fp16/int8."
));
break
;
}
switch
(
bq_tensor
->
dtype
())
{
case
paddle
::
experimental
::
DataType
::
FLOAT16
:
QKVBiasProcess
<
platform
::
float16
>
(
bq_tensor
,
bk_tensor
,
bv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
case
paddle
::
experimental
::
DataType
::
FLOAT32
:
QKVBiasProcess
<
float
>
(
bq_tensor
,
bk_tensor
,
bv_tensor
,
num_head
,
dim_head
,
dim_embed
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"fused_multi_transformer not supported bias dtype. "
"we now only support fp32/fp16."
));
break
;
}
}
template
<
typename
T
>
return
fusion_count
;
inline
void
QKVWeightsProcessFuseQKV
(
phi
::
DenseTensor
*
qkv_w_tensor
,
}
const
int
num_head
,
const
int
dim_head
,
const
int
dim_embed
)
{
auto
*
qkv_w_data
=
qkv_w_tensor
->
data
<
T
>
();
auto
transpose_w_dims
=
phi
::
make_ddim
({
3
,
num_head
,
dim_head
,
dim_embed
});
phi
::
DenseTensor
tmp_transpose_w_tensor
;
void
FusedMultiTransformerEncoderPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
tmp_transpose_w_tensor
.
Resize
(
transpose_w_dims
);
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
tmp_transpose_w_data
=
auto
*
scope
=
param_scope
();
tmp_transpose_w_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
Fatal
(
"During the multi_transformer pass, The scope should not be null."
));
// transpose qkv matmul Y to QKVWeights
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
if
(
fusion_count
>
0
)
{
for
(
int
j
=
0
;
j
<
num_head
;
j
++
)
{
graph
->
Set
(
kFusedMultiTransformerEncoderPass
,
new
bool
(
true
));
for
(
int
k
=
0
;
k
<
dim_head
;
k
++
)
{
graph
->
Set
(
kFusedMultiTransformerEncoderFusionCount
,
new
int
(
fusion_count
));
for
(
int
l
=
0
;
l
<
dim_embed
;
l
++
)
{
int
out_idx
=
i
*
num_head
*
dim_head
*
dim_embed
+
j
*
dim_head
*
dim_embed
+
k
*
dim_embed
+
l
;
int
in_idx
=
l
*
num_head
*
3
*
dim_head
+
j
*
3
*
dim_head
+
i
*
dim_head
+
k
;
tmp_transpose_w_data
[
out_idx
]
=
qkv_w_data
[
in_idx
];
}
}
}
}
}
AddStatis
(
fusion_count
);
qkv_w_tensor
->
Resize
(
transpose_w_dims
);
auto
*
new_transpose_w_data
=
qkv_w_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
memcpy
(
new_transpose_w_data
,
tmp_transpose_w_data
,
sizeof
(
T
)
*
qkv_w_tensor
->
numel
());
}
}
template <typename T>
inline void QKVBiasProcessFuseQKV(phi::DenseTensor* qkv_b_tensor,
                                  const int num_head,
                                  const int dim_head,
                                  const int dim_embed) {
  auto* qkv_b_data = qkv_b_tensor->data<T>();
  auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head});
  phi::DenseTensor tmp_transpose_b_tensor;
  tmp_transpose_b_tensor.Resize(transpose_b_dims);
  auto* tmp_transpose_b_data =
      tmp_transpose_b_tensor.mutable_data<T>(platform::CPUPlace());

  // transpose qkv elementwise_add Y to QKVBias
  for (int i = 0; i < 3; i++) {
    for (int j = 0; j < num_head; j++) {
      for (int k = 0; k < dim_head; k++) {
        int out_idx = i * num_head * dim_head + j * dim_head + k;
        int in_idx = j * 3 * dim_head + i * dim_head + k;
        tmp_transpose_b_data[out_idx] = qkv_b_data[in_idx];
      }
    }
  }

  qkv_b_tensor->Resize({3, num_head, dim_head});
  auto* new_transpose_b_data =
      qkv_b_tensor->mutable_data<T>(platform::CPUPlace());
  memcpy(new_transpose_b_data,
         tmp_transpose_b_data,
         sizeof(T) * qkv_b_tensor->numel());
}

inline void QKVWeightsBiasProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
                                         phi::DenseTensor* qkv_b_tensor,
                                         const int num_head,
                                         const int dim_head,
                                         const int dim_embed) {
  switch (qkv_w_tensor->dtype()) {
    case paddle::experimental::DataType::FLOAT16:
      QKVWeightsProcessFuseQKV<platform::float16>(
          qkv_w_tensor, num_head, dim_head, dim_embed);
      break;
    case paddle::experimental::DataType::FLOAT32:
      QKVWeightsProcessFuseQKV<float>(
          qkv_w_tensor, num_head, dim_head, dim_embed);
      break;
    case paddle::experimental::DataType::INT8:
      QKVWeightsProcessFuseQKV<int8_t>(
          qkv_w_tensor, num_head, dim_head, dim_embed);
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "fused_multi_transformer not supported weight dtype. "
          "we now only support fp32/fp16/int8."));
      break;
  }
  switch (qkv_b_tensor->dtype()) {
    case paddle::experimental::DataType::FLOAT16:
      QKVBiasProcessFuseQKV<platform::float16>(
          qkv_b_tensor, num_head, dim_head, dim_embed);
      break;
    case paddle::experimental::DataType::FLOAT32:
      QKVBiasProcessFuseQKV<float>(
          qkv_b_tensor, num_head, dim_head, dim_embed);
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "fused_multi_transformer not supported bias dtype. "
          "we now only support fp32/fp16."));
      break;
  }
}

// Just use for fused_multi_transformer_int8
inline void TransposeWeights(phi::DenseTensor* weight_tensor) {
  int m = weight_tensor->dims()[0];
  int n = weight_tensor->dims()[1];
  phi::DenseTensor tmp_weight_tensor;
  auto tmp_weight_data =
      tmp_weight_tensor.mutable_data<int8_t>({n, m}, platform::CPUPlace());
  auto weight_data = weight_tensor->data<int8_t>();
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int in_idx = i * n + j;
      int out_idx = j * m + i;
      tmp_weight_data[out_idx] = weight_data[in_idx];
    }
  }
  weight_tensor->Resize({n, m});
  auto new_weight_data =
      weight_tensor->mutable_data<int8_t>(platform::CPUPlace());
  memcpy(new_weight_data, tmp_weight_data, sizeof(int8_t) * m * n);
}

inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) {
  auto var_desc = VarDesc(name);
  var_desc.SetDataType(framework::proto::VarType::FP32);
  var_desc.SetPersistable(true);
  auto node = graph->CreateVarNode(&var_desc);
  return node;
}

FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
  AddOpCompat(OpCompat("layer_norm"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Scale")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddOutput("Y")
      .IsTensor()
      .End()
      .AddOutput("Mean")
      .IsTensor()
      .End()
      .AddOutput("Variance")
      .IsTensor()
      .End()
      .AddAttr("epsilon")
      .IsNumGE(0.0f)
      .IsNumLE(0.001f)
      .End()
      .AddAttr("begin_norm_axis")
      .IsNumGT(0)
      .End();

  AddOpCompat(OpCompat("matmul_v2"))
      .AddInput("X")  // the shape should be (B, S, N*H)
      .IsTensor()
      .End()
      .AddInput("Y")  // the shape should be (N*H, N*H)
      .IsTensor()
      .End()
      .AddOutput("Out")  // the shape should be (B, S, N*H)
      .IsTensor()
      .End()
      .AddAttr("trans_x")
      .IsType<bool>()
      .End()
      .AddAttr("trans_y")
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsIntIn({2, -1, 0})
      .End();

  AddOpCompat(OpCompat("reshape2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Shape")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ShapeTensor")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsOptional()
      .IsTensor()
      .End()
      .AddAttr("shape")  // -->(B, S, H, N) <--(B, S, N*H)
      .IsType<std::vector<int>>()
      .End();

  AddOpCompat(OpCompat("transpose2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsOptional()
      .IsTensor()
      .End()
      .AddAttr("axis")  // {0, 2, 1, 3}
      .IsType<std::vector<int>>()
      .End();

  AddOpCompat(OpCompat("scale"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("scale")
      .IsType<float>()  // copy to new op. so unconstrained.
      .End()
      .AddAttr("bias")
      .IsNumEQ(0.f)
      .End()
      .AddAttr("bias_after_scale")  // bias is 0, so unconstrained.
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("softmax"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsIntIn({-1, 3})  // shape is (B, H, S, S), so axis is -1 or 3
      .End();

  AddOpCompat(OpCompat("gelu"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("approximate")
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("relu"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();
}
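// A toy-size sketch of the row-major (m, n) -> (n, m) relayout that
// TransposeWeights above applies to int8 weights (standalone, std::vector in
// place of phi::DenseTensor):
#include <cstdint>
#include <vector>

int main() {
  const int m = 2, n = 3;
  std::vector<int8_t> w = {1, 2, 3, 4, 5, 6};  // (2, 3) row-major
  std::vector<int8_t> wt(n * m);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) wt[j * m + i] = w[i * n + j];
  // wt == {1, 4, 2, 5, 3, 6}, i.e. the same data in (3, 2) layout.
  return wt[1] == 4 ? 0 : 1;
}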
-int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
-                                                  const std::string& name_scope,
-                                                  Scope* scope) const {
+int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
+    Graph* graph, const std::string& name_scope, Scope* scope) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  bool enable_int8 = graph->Get<bool>("enable_int8");
  if (enable_int8) {
-    VLOG(3) << "FusedMultiTransformerEncoderPass with int8";
+    VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with int8";
  } else {
-    VLOG(3) << "FusedMultiTransformerEncoderPass with fp";
+    VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with fp";
  }

  // Create pattern.
-  patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern(
-      pattern, name_scope);
-  fused_multi_transformer_pattern();
+  patterns::FusedMultiTransformerEncoderFuseQKVPattern
+      fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
+  fused_multi_transformer_fuse_qkv_pattern();
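  // The detector flow in each BuildFusion in this file follows the same shape
  // (a condensed sketch, not code from this commit; "SomePattern" is a
  // placeholder):
  //
  //   GraphPatternDetector gpd;
  //   patterns::SomePattern pat(gpd.mutable_pattern(), name_scope);
  //   pat();                                  // register the pattern nodes
  //   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
  //                      Graph* g) { /* rewrite one matched subgraph */ };
  //   gpd(graph, handler);                    // match + rewrite
  //   return fusion_count;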
  // Create New OpDesc
  auto fuse_creater = [&](Node* input0,
...
@@ -1316,13 +2578,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
                          Node* layer_norm_variance,
                          Node* matmul0,
                          Node* matmul0_w,
-                          Node* matmul1_w,
-                          Node* matmul2_w,
                          Node* eltadd0_b,
-                          Node* eltadd1_b,
-                          Node* eltadd2_b,
-                          Node* transpose2_1_out,
-                          Node* transpose2_2_out,
+                          Node* split0_k_out,
+                          Node* split0_v_out,
                          Node* eltadd_qk_b,
                          Node* reshape2_0,
                          Node* matmul_linear,
...
@@ -1340,6 +2598,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
                          Node* ffn_matmul1_w,
                          Node* ffn_eltadd0_b,
                          Node* ffn_eltadd1_b,
+                          Node* ffn_act,
                          Node* ffn_output) {
    auto* matmul0_op = matmul0->Op();
    auto* matmul_linear_op = matmul_linear->Op();
...
@@ -1355,43 +2614,28 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
    auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
    int layer_idx = atoi(ln_idx_str.c_str()) / 2;

-    auto* wq_tensor =
-        scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>();
-    auto* wk_tensor =
-        scope->FindVar(matmul1_w->Name())->GetMutable<phi::DenseTensor>();
-    auto* wv_tensor =
-        scope->FindVar(matmul2_w->Name())->GetMutable<phi::DenseTensor>();
-    auto* bq_tensor =
-        scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
-    auto* bk_tensor =
-        scope->FindVar(eltadd1_b->Name())->GetMutable<phi::DenseTensor>();
-    auto* bv_tensor =
-        scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();
+    auto* qkv_w_tensor =
+        scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>();
+    auto* qkv_b_tensor =
+        scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();

    // NOTE(minghaoBD): to make it compatible with structured pruning on
    // num_head dimension:
    // 1. get dim_head from reshape.shape[3], dim_embed from
    // layer_norm_bias.shape[0]
-    // 2. calculate num_head according to wq_tensor.shape[1] and dim_head
+    // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head
    auto reshape_desc = reshape2_0->Op();
    int dim_head =
        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(3);
+            .at(3) / 3;  // 3 for qkv
    auto* layer_norm_bias_tensor =
        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
    int dim_embed = layer_norm_bias_tensor->dims()[0];
-    int num_head = wq_tensor->dims()[1] / dim_head;
+    int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head;

-    QKVWeightsBiasProcess(wq_tensor,
-                          wk_tensor,
-                          wv_tensor,
-                          bq_tensor,
-                          bk_tensor,
-                          bv_tensor,
-                          num_head,
-                          dim_head,
-                          dim_embed);
+    QKVWeightsBiasProcessFuseQKV(
+        qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
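    // Worked shape arithmetic for the fused-QKV branch above (illustrative
    // numbers, not taken from the commit): with hidden size 768 and 12 heads,
    // reshape2_0's "shape" attr is (B, S, 12, 192), so
    //   dim_head  = 192 / 3 = 64            // the trailing dim packs q, k, v
    //   dim_embed = layer_norm_bias.dims()[0] = 768
    //   num_head  = qkv_w.dims()[1] / 3 / dim_head = 2304 / 3 / 64 = 12
    // and a LayerNorm scale named like "...layer_norm_10..." gives
    //   layer_idx = 10 / 2 = 5, since each transformer layer owns two
    //   LayerNorms (attention + FFN).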
    if (enable_int8) {
      auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name())
...
@@ -1406,18 +2650,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
      TransposeWeights(ffn1_w_tensor);
    }
-    // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
-    auto* combined_w_desc = matmul0_w->Var();
-    combined_w_desc->SetShape({3, num_head, dim_head, dim_embed});
-    combined_w_desc->SetPersistable(true);
-
-    auto* combined_bias_desc = eltadd0_b->Var();
-    combined_bias_desc->SetShape({3, num_head, dim_head});
-    combined_bias_desc->SetPersistable(true);
-
-    scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()});
-    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
-
    // create fused_multi_transformer
    OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
    fused_multi_transformer_op_desc.SetType(enable_int8
...
@@ -1442,7 +2674,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
    VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
    // FIXME: only support max_seq_len <= 1024
    cache_kv_desc.SetDataType(
-        framework::TransToProtoVarType(bq_tensor->dtype()));
+        framework::TransToProtoVarType(qkv_b_tensor->dtype()));
    cache_kv_desc.SetPersistable(false);
    auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
...
@@ -1455,9 +2687,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
    fill_const_op_desc.SetAttr("input_dim_idx", 0);
    fill_const_op_desc.SetAttr("output_dim_idx", 1);
    fill_const_op_desc.SetAttr("value", 0);
    fill_const_op_desc.SetAttr(
        "dtype",
-        static_cast<int>(framework::TransToProtoVarType(bq_tensor->dtype())));
+        static_cast<int>(framework::TransToProtoVarType(qkv_b_tensor->dtype())));
    auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
    fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
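    // What the fill op above materializes at runtime (an assumption based on
    // fused_multi_transformer's usual CacheKV convention, not spelled out in
    // this hunk): a zero-initialized tensor shaped roughly
    //   [2, batch_size, num_head, max_seq_len, dim_head]
    // where dim 0 packs K and V and batch_size is copied from input dim 0
    // ("input_dim_idx", 0) into output dim 1 ("output_dim_idx", 1). E.g. for
    // batch 8, 12 heads, max_seq_len 1024, dim_head 64 that is
    //   2 * 8 * 12 * 1024 * 64 = 12,582,912 elements per layer.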
...
@@ -1490,12 +2722,17 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
    fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
    fused_multi_transformer_op_desc.SetAttr(
        "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());

    // output dropout attribute
    fused_multi_transformer_op_desc.SetAttr("is_test", true);
    fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);

    // Quantization attribute/Input
    if (enable_int8) {
+      auto* dev_ctx = static_cast<phi::CPUContext*>(
+          platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
      // Set input scale
      std::string qkv_input_name = matmul0_op->Input("X")[0];
      auto qkv_in_scale = PADDLE_GET_CONST(
...
@@ -1512,6 +2749,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
          float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));

      // Calc outscale and Set them
+      // TODO(wufeisheng): Currently just match layer-wise weight scale, where
+      // channel-wise weight scale should also be supported.
      auto qkv_weight_scale =
          PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale"));
      auto out_weight_scale =
...
@@ -1555,36 +2794,44 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
      auto ffn1_out_scale_var =
          scope->Var(ffn_matmul1_w->Name() + "_out_scale");

-      auto qkv_out_scale_data =
-          qkv_out_scale_var->GetMutable<phi::DenseTensor>()
-              ->mutable_data<float>({3 * dim_embed}, platform::CPUPlace());
+      auto* qkv_out_scale_tensor =
+          qkv_out_scale_var->GetMutable<phi::DenseTensor>();
+      qkv_out_scale_tensor->Resize({3 * dim_embed});
+      dev_ctx->Alloc<float>(qkv_out_scale_tensor);
+      auto qkv_out_scale_data = qkv_out_scale_tensor->data<float>();
      memcpy(qkv_out_scale_data,
             qkv_out_scales.data(),
             qkv_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "QKVOutScale", {matmul0_w->Name() + "_out_scale"});

-      auto out_out_scale_data =
-          out_out_scale_var->GetMutable<phi::DenseTensor>()
-              ->mutable_data<float>({dim_embed}, platform::CPUPlace());
+      auto* out_out_scale_tensor =
+          out_out_scale_var->GetMutable<phi::DenseTensor>();
+      out_out_scale_tensor->Resize({dim_embed});
+      dev_ctx->Alloc<float>(out_out_scale_tensor);
+      auto out_out_scale_data = out_out_scale_tensor->data<float>();
      memcpy(out_out_scale_data,
             out_out_scales.data(),
             out_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});

-      auto ffn0_out_scale_data =
-          ffn0_out_scale_var->GetMutable<phi::DenseTensor>()
-              ->mutable_data<float>({4 * dim_embed}, platform::CPUPlace());
+      auto* ffn0_out_scale_tensor =
+          ffn0_out_scale_var->GetMutable<phi::DenseTensor>();
+      ffn0_out_scale_tensor->Resize({4 * dim_embed});
+      dev_ctx->Alloc<float>(ffn0_out_scale_tensor);
+      auto ffn0_out_scale_data = ffn0_out_scale_tensor->data<float>();
      memcpy(ffn0_out_scale_data,
             ffn0_out_scales.data(),
             ffn0_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});

-      auto ffn1_out_scale_data =
-          ffn1_out_scale_var->GetMutable<phi::DenseTensor>()
-              ->mutable_data<float>({dim_embed}, platform::CPUPlace());
+      auto* ffn1_out_scale_tensor =
+          ffn1_out_scale_var->GetMutable<phi::DenseTensor>();
+      ffn1_out_scale_tensor->Resize({dim_embed});
+      dev_ctx->Alloc<float>(ffn1_out_scale_tensor);
+      auto ffn1_out_scale_data = ffn1_out_scale_tensor->data<float>();
      memcpy(ffn1_out_scale_data,
             ffn1_out_scales.data(),
             ffn1_out_scales.size() * sizeof(float));
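      // Worked example of the dequantization scales copied above (illustrative
      // numbers, not from the commit): with an input scale of 127.0f and a
      // weight scale of 254.0f each per-channel value is
      //   (254.0f / 127.0f) * (127.0f / 127.0f) == 2.0f,
      // broadcast over 3 * dim_embed entries for QKV, dim_embed for the
      // out-linear and FFN2 projections, and 4 * dim_embed for FFN1.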
...
@@ -1641,27 +2888,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
      auto while_Xs = while0->Op()->Input("X");
      while_Xs.erase(
          std::remove(
-              std::begin(while_Xs), std::end(while_Xs), transpose2_1_out->Name()),
-          std::end(while_Xs));
-      while_Xs.erase(
-          std::remove(
-              std::begin(while_Xs), std::end(while_Xs), transpose2_2_out->Name()),
-          std::end(while_Xs));
-      while_Xs.erase(
-          std::remove(
-              std::begin(while_Xs), std::end(while_Xs), matmul1_w->Name()),
-          std::end(while_Xs));
-      while_Xs.erase(
-          std::remove(
-              std::begin(while_Xs), std::end(while_Xs), matmul2_w->Name()),
-          std::end(while_Xs));
-      while_Xs.erase(
-          std::remove(
-              std::begin(while_Xs), std::end(while_Xs), eltadd1_b->Name()),
-          std::end(while_Xs));
-      while_Xs.erase(
-          std::remove(
-              std::begin(while_Xs), std::end(while_Xs), eltadd2_b->Name()),
+              std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()),
+          std::end(while_Xs));
+      while_Xs.erase(
+          std::remove(
+              std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()),
          std::end(while_Xs));
      while_Xs.emplace_back(cache_kv->Name());
      while0->Op()->SetInput("X", while_Xs);
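      // The erase-remove calls above are the standard C++ idiom for dropping
      // every occurrence of a value from a vector; a one-line sketch:
      //   xs.erase(std::remove(xs.begin(), xs.end(), name), xs.end());
      // std::remove shifts the kept elements forward and returns the new
      // logical end; erase then trims the leftover tail.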
...
@@ -1670,13 +2901,13 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
      // 1. delete k, v
      // 2. add cache_kv
      auto while_Outs = while0->Op()->Output("Out");
      while_Outs.erase(
          std::remove(
-              std::begin(while_Outs), std::end(while_Outs), transpose2_1_out->Name()),
+              std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()),
          std::end(while_Outs));
      while_Outs.erase(
          std::remove(
-              std::begin(while_Outs), std::end(while_Outs), transpose2_2_out->Name()),
+              std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()),
          std::end(while_Outs));
      while_Outs.emplace_back(cache_kv->Name());
      while0->Op()->SetOutput("Out", while_Outs);
...
@@ -1684,213 +2915,200 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
      // link CacheKV to while
      IR_NODE_LINK_TO(cache_kv, while0)
      // unlink origin KV output to while
-      IR_NODE_UNLINK(transpose2_1_out, while0);
-      IR_NODE_UNLINK(transpose2_2_out, while0);
-      IR_NODE_UNLINK(while0, transpose2_1_out);
-      IR_NODE_UNLINK(while0, transpose2_2_out);
-      // unlink KV weight/bias to while after merged into Q weight/bias
-      IR_NODE_UNLINK(matmul1_w, while0);
-      IR_NODE_UNLINK(matmul2_w, while0);
-      IR_NODE_UNLINK(eltadd1_b, while0);
-      IR_NODE_UNLINK(eltadd2_b, while0);
+      IR_NODE_UNLINK(split0_k_out, while0);
+      IR_NODE_UNLINK(split0_v_out, while0);
+      IR_NODE_UNLINK(while0, split0_k_out);
+      IR_NODE_UNLINK(while0, split0_v_out);
  };

  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    if (!IsCompat(subgraph, graph)) {
-      LOG(WARNING) << "fused_multi_transformer_encoder pass in "
-                      "op compat failed.";
+      LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv "
+                      "pass in op compat failed.";
      return;
    }
-    VLOG(4) << "handle MultiTransformer encoder fuse";
-
-    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul0, matmul0, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul0_out, matmul0_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul0_w, matmul0_w, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul1, matmul1, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul1_out, matmul1_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul1_w, matmul1_w, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul2, matmul2, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul2_out, matmul2_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul2_w, matmul2_w, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(attention_output, attention_output, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(while0, while0, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, ffn_layer_norm_scale, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, ffn_layer_norm_bias, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, ffn_layer_norm_mean, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, ffn_layer_norm_variance, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, ffn_layer_norm_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_output, ffn_output, fused_multi_transformer_pattern)
-
-    // nodes need be removed
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear, matmul_linear, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear, eltadd_linear, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fused_multi_transformer_pattern)
+    VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse";
+
+    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(split0, split0, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, ffn_layer_norm_scale, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, ffn_layer_norm_bias, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, ffn_layer_norm_mean, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, ffn_layer_norm_variance, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, ffn_layer_norm_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
+    GET_IR_NODE_FROM_SUBGRAPH(ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
+
+    // nodes need be removed
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, fused_multi_transformer_fuse_qkv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern)
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, matmul_linear_w, fused_multi_transformer_fuse_qkv_pattern)
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, matmul_linear_out, fused_multi_transformer_fuse_qkv_pattern)
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern)
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_fuse_qkv_pattern)
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_fuse_qkv_pattern)
+    GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
+    GET_IR_NODE_FROM_SUBGRAPH(while0, while0, fused_multi_transformer_fuse_qkv_pattern)

    fuse_creater(input0,
                 layer_norm,
...
@@ -1900,13 +3118,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
                 layer_norm_variance,
                 matmul0,
                 matmul0_w,
-                 matmul1_w,
-                 matmul2_w,
                 eltadd0_b,
-                 eltadd1_b,
-                 eltadd2_b,
-                 transpose2_1_out,
-                 transpose2_2_out,
+                 split0_k_out,
+                 split0_v_out,
                 eltadd_qk_b,
                 reshape2_0,
                 matmul_linear,
...
@@ -1924,6 +3138,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
                 ffn_matmul1_w,
                 ffn_eltadd0_b,
                 ffn_eltadd1_b,
+                 ffn_act,
                 ffn_output);

    std::unordered_set<const Node*> marked_nodes({layer_norm,
...
@@ -1931,31 +3146,21 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
                                                  layer_norm_variance,
                                                  layer_norm_out,
                                                  matmul0,
-                                                  matmul1,
-                                                  matmul2,
                                                  matmul0_out,
-                                                  matmul1_out,
-                                                  matmul2_out,
                                                  eltadd0,
-                                                  eltadd1,
-                                                  eltadd2,
                                                  eltadd0_out,
-                                                  eltadd1_out,
-                                                  eltadd2_out,
                                                  reshape2_0,
-                                                  reshape2_1,
-                                                  reshape2_2,
                                                  reshape2_0_out,
-                                                  reshape2_1_out,
-                                                  reshape2_2_out,
                                                  transpose2_0,
-                                                  transpose2_1,
-                                                  transpose2_2,
                                                  transpose2_0_out,
-                                                  transpose2_1_out,
-                                                  transpose2_2_out,
+                                                  split0,
+                                                  split0_q_out,
+                                                  split0_k_out,
+                                                  split0_v_out,
                                                  matmul_qk,
                                                  matmul_qk_out,
+                                                  scale_qk,
+                                                  scale_qk_out,
                                                  eltadd_qk,
                                                  eltadd_qk_out,
                                                  softmax_qk,
...
@@ -1984,8 +3189,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
                                                  ffn_eltadd1,
                                                  ffn_eltadd0_out,
                                                  ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                  ffn_eltadd_out});

    // Remove unneeded nodes.
...
@@ -1997,23 +3202,25 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
  return fusion_count;
}
-void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
+void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal(
-          "During the multi_transformer pass, The scope should not be null."));
+          "During the fused_multi_transformer_encoder pass, "
+          "The scope should not be null."));

  int fusion_count = BuildFusion(graph, name_scope_, scope);
  if (fusion_count > 0) {
-    graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
+    graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true));
    graph->Set(kFusedMultiTransformerEncoderFusionCount,
               new int(fusion_count));
  }
  AddStatis(fusion_count);
}
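// For reference, a pass like the ones above is exercised from the inference
// side by putting its name into the IR pass list; a minimal usage sketch
// (API usage assumed, not taken from this commit):
//
//   paddle_infer::Config config(model_dir);
//   config.pass_builder()->AppendPass("fused_multi_transformer_encoder_pass");
//   auto predictor = paddle_infer::CreatePredictor(config);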
-FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
+FusedMultiTransformerEncoderFuseQKVPass::
+    FusedMultiTransformerEncoderFuseQKVPass() {
  AddOpCompat(OpCompat("layer_norm"))
      .AddInput("X")
      .IsTensor()
...
@@ -2041,6 +3248,23 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
      .IsNumGT(0)
      .End();

+  AddOpCompat(OpCompat("scale"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("scale")
+      .IsType<float>()  // copy to new op. so unconstrained.
+      .End()
+      .AddAttr("bias")
+      .IsNumEQ(0.f)
+      .End()
+      .AddAttr("bias_after_scale")  // bias is 0, so unconstrained.
+      .IsType<bool>()
+      .End();
+
  AddOpCompat(OpCompat("matmul_v2"))
      .AddInput("X")  // the shape should be (B, S, N*H)
      .IsTensor()
...
@@ -2168,56 +3392,54 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
      .End();
}
-int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
+int MultiDevicesFusedMultiTransformerEncoderPass::BuildFusion(
    Graph* graph, const std::string& name_scope, Scope* scope) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

-  bool enable_int8 = graph->Get<bool>("enable_int8");
-  if (enable_int8) {
-    VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with int8";
-  } else {
-    VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with fp";
-  }
-
  // Create pattern.
-  patterns::FusedMultiTransformerEncoderFuseQKVPattern
-      fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
-  fused_multi_transformer_fuse_qkv_pattern();
+  patterns::MultiDevicesFusedMultiTransformerEncoderPattern
+      multi_devices_fused_multi_transformer_pattern(pattern, name_scope);
+  multi_devices_fused_multi_transformer_pattern();

  // Create New OpDesc
  auto fuse_creater = [&](Node* input0,
+                          Node* c_identity,
                          Node* layer_norm,
                          Node* layer_norm_scale,
                          Node* layer_norm_bias,
                          Node* layer_norm_mean,
                          Node* layer_norm_variance,
-                          Node* matmul0,
                          Node* matmul0_w,
+                          Node* matmul1_w,
+                          Node* matmul2_w,
                          Node* eltadd0_b,
-                          Node* split0_k_out,
-                          Node* split0_v_out,
+                          Node* eltadd1_b,
+                          Node* eltadd2_b,
+                          Node* transpose2_1_out,
+                          Node* transpose2_2_out,
                          Node* eltadd_qk_b,
                          Node* reshape2_0,
                          Node* matmul_linear,
                          Node* matmul_linear_w,
                          Node* eltadd_linear_b,
-                          Node* while0,
                          Node* ffn_layer_norm,
                          Node* ffn_layer_norm_scale,
                          Node* ffn_layer_norm_bias,
                          Node* ffn_layer_norm_mean,
                          Node* ffn_layer_norm_variance,
-                          Node* ffn_matmul0,
                          Node* ffn_matmul0_w,
-                          Node* ffn_matmul1,
                          Node* ffn_matmul1_w,
                          Node* ffn_eltadd0_b,
                          Node* ffn_eltadd1_b,
-                          Node* ffn_output) {
-    auto* matmul0_op = matmul0->Op();
-    auto* matmul_linear_op = matmul_linear->Op();
-    auto* ffn_matmul_0_op = ffn_matmul0->Op();
-    auto* ffn_matmul_1_op = ffn_matmul1->Op();
+                          Node* ffn_act,
+                          Node* ffn_layer_norm_out) {
+    auto reshape_desc = reshape2_0->Op();
+    int num_head =
+        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
+            .at(2);
+    int dim_head =
+        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
+            .at(3);

    // Calc index of transformer layer by LayerNorm Scale name
    // This calculation assumes:
...
@@ -2228,47 +3450,47 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
    auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
    int layer_idx = atoi(ln_idx_str.c_str()) / 2;

-    auto* qkv_w_tensor =
-        scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>();
-    auto* qkv_b_tensor =
-        scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
-
-    // NOTE(minghaoBD): to make it compatible with structured pruning on
-    // num_head dimension:
-    // 1. get dim_head from reshape.shape[3], dim_embed from
-    // layer_norm_bias.shape[0]
-    // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head
-    auto reshape_desc = reshape2_0->Op();
-    int dim_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(3) / 3;  // 3 for qkv
-    auto* layer_norm_bias_tensor =
-        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
-    int dim_embed = layer_norm_bias_tensor->dims()[0];
-    int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head;
-
-    QKVWeightsBiasProcessFuseQKV(
-        qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
-
-    if (enable_int8) {
-      auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name())
-                                      ->GetMutable<phi::DenseTensor>();
-      auto* ffn0_w_tensor =
-          scope->FindVar(ffn_matmul0_w->Name())->GetMutable<phi::DenseTensor>();
-      auto* ffn1_w_tensor =
-          scope->FindVar(ffn_matmul1_w->Name())->GetMutable<phi::DenseTensor>();
-      TransposeWeights(out_linear_w_tensor);
-      TransposeWeights(ffn0_w_tensor);
-      TransposeWeights(ffn1_w_tensor);
-    }
+    auto* wq_tensor =
+        scope->FindVar(matmul0_w->Name())->GetMutable<phi::DenseTensor>();
+    auto* wk_tensor =
+        scope->FindVar(matmul1_w->Name())->GetMutable<phi::DenseTensor>();
+    auto* wv_tensor =
+        scope->FindVar(matmul2_w->Name())->GetMutable<phi::DenseTensor>();
+    auto* bq_tensor =
+        scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
+    auto* bk_tensor =
+        scope->FindVar(eltadd1_b->Name())->GetMutable<phi::DenseTensor>();
+    auto* bv_tensor =
+        scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();
+    int dim_embed = wq_tensor->dims()[0];
+
+    QKVWeightsBiasProcess(wq_tensor,
+                          wk_tensor,
+                          wv_tensor,
+                          bq_tensor,
+                          bk_tensor,
+                          bv_tensor,
+                          num_head,
+                          dim_head,
+                          dim_embed);
+
+    // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
+    auto* combined_w_desc = matmul0_w->Var();
+    combined_w_desc->SetShape({3, num_head, dim_head, dim_embed});
+    combined_w_desc->SetPersistable(true);
+
+    auto* combined_bias_desc = eltadd0_b->Var();
+    combined_bias_desc->SetShape({3, num_head, dim_head});
+    combined_bias_desc->SetPersistable(true);
+
+    scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()});
+    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

    // create fused_multi_transformer
    OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
-    fused_multi_transformer_op_desc.SetType(enable_int8
-                                                ? "fused_multi_transformer_int8"
-                                                : "fused_multi_transformer");
+    fused_multi_transformer_op_desc.SetType("fused_multi_transformer");

    // 1. Input setting
    fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
...
@@ -2288,7 +3510,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
    VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
    // FIXME: only support max_seq_len <= 1024
    cache_kv_desc.SetDataType(
-        framework::TransToProtoVarType(qkv_b_tensor->dtype()));
+        framework::TransToProtoVarType(wq_tensor->dtype()));
    cache_kv_desc.SetPersistable(false);
    auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
...
@@ -2301,9 +3523,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
    fill_const_op_desc.SetAttr("input_dim_idx", 0);
    fill_const_op_desc.SetAttr("output_dim_idx", 1);
    fill_const_op_desc.SetAttr("value", 0);
    fill_const_op_desc.SetAttr(
        "dtype",
-        static_cast<int>(framework::TransToProtoVarType(qkv_b_tensor->dtype())));
+        static_cast<int>(framework::TransToProtoVarType(wq_tensor->dtype())));
    auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
    fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...
@@ -2329,137 +3551,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2329,137 +3551,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
{
ffn_eltadd1_b
->
Name
()});
{
ffn_eltadd1_b
->
Name
()});
// 2. Output setting
// 2. Output setting
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_output
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_layer_norm_out
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv
->
Name
()});
// Attribute setting
// Attribute setting
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
tru
e
);
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
fals
e
);
fused_multi_transformer_op_desc
.
SetAttr
(
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
"act_method"
,
{
ffn_act
->
Op
()
->
Type
()});
// Quantization attribute/Input
// parallel ring id
if
(
enable_int8
)
{
auto
*
c_identity_op
=
c_identity
->
Op
();
// Set input scale
fused_multi_transformer_op_desc
.
SetAttr
(
"ring_id"
,
std
::
string
qkv_input_name
=
matmul0_op
->
Input
(
"X"
)[
0
];
c_identity_op
->
GetAttr
(
"ring_id"
));
auto
qkv_in_scale
=
PADDLE_GET_CONST
(
float
,
matmul0_op
->
GetAttr
(
"Input_scale_"
+
qkv_input_name
));
std
::
string
out_linear_input_name
=
matmul_linear_op
->
Input
(
"X"
)[
0
];
auto
out_linear_in_scale
=
PADDLE_GET_CONST
(
float
,
matmul_linear_op
->
GetAttr
(
"Input_scale_"
+
out_linear_input_name
));
std
::
string
ffn0_input_name
=
ffn_matmul_0_op
->
Input
(
"X"
)[
0
];
auto
ffn0_in_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_0_op
->
GetAttr
(
"Input_scale_"
+
ffn0_input_name
));
std
::
string
ffn1_input_name
=
ffn_matmul_1_op
->
Input
(
"X"
)[
0
];
auto
ffn1_in_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_1_op
->
GetAttr
(
"Input_scale_"
+
ffn1_input_name
));
// Calc outscale and Set them
// TODO(wufeisheng): Currently just match layer-wise weight scale, where
// channel-wise weight scale should also be surpported.
auto
qkv_weight_scale
=
PADDLE_GET_CONST
(
float
,
matmul0_op
->
GetAttr
(
"weight_scale"
));
auto
out_weight_scale
=
PADDLE_GET_CONST
(
float
,
matmul_linear_op
->
GetAttr
(
"weight_scale"
));
auto
ffn0_weight_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_0_op
->
GetAttr
(
"weight_scale"
));
auto
ffn1_weight_scale
=
PADDLE_GET_CONST
(
float
,
ffn_matmul_1_op
->
GetAttr
(
"weight_scale"
));
auto
qkv_out_scales
=
std
::
vector
<
float
>
(
3
*
dim_embed
,
(
qkv_weight_scale
/
127.0
f
)
*
(
qkv_in_scale
/
127.0
f
));
auto
out_out_scales
=
std
::
vector
<
float
>
(
dim_embed
,
(
out_weight_scale
/
127.0
f
)
*
(
out_linear_in_scale
/
127.0
f
));
auto
ffn0_out_scales
=
std
::
vector
<
float
>
(
4
*
dim_embed
,
(
ffn0_weight_scale
/
127.0
f
)
*
(
ffn0_in_scale
/
127.0
f
));
auto
ffn1_out_scales
=
std
::
vector
<
float
>
(
dim_embed
,
(
ffn1_weight_scale
/
127.0
f
)
*
(
ffn1_in_scale
/
127.0
f
));
// Inverse input scale
qkv_in_scale
=
1.0
f
/
qkv_in_scale
;
out_linear_in_scale
=
1.0
f
/
out_linear_in_scale
;
ffn0_in_scale
=
1.0
f
/
ffn0_in_scale
;
ffn1_in_scale
=
1.0
f
/
ffn1_in_scale
;
fused_multi_transformer_op_desc
.
SetAttr
(
"qkv_in_scale"
,
std
::
vector
<
float
>
{
qkv_in_scale
});
fused_multi_transformer_op_desc
.
SetAttr
(
"out_linear_in_scale"
,
std
::
vector
<
float
>
{
out_linear_in_scale
});
fused_multi_transformer_op_desc
.
SetAttr
(
"ffn1_in_scale"
,
std
::
vector
<
float
>
{
ffn0_in_scale
});
fused_multi_transformer_op_desc
.
SetAttr
(
"ffn2_in_scale"
,
std
::
vector
<
float
>
{
ffn1_in_scale
});
      auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale");
      auto out_out_scale_var =
          scope->Var(matmul_linear_w->Name() + "_out_scale");
      auto ffn0_out_scale_var =
          scope->Var(ffn_matmul0_w->Name() + "_out_scale");
      auto ffn1_out_scale_var =
          scope->Var(ffn_matmul1_w->Name() + "_out_scale");
      auto qkv_out_scale_data =
          qkv_out_scale_var->GetMutable<phi::DenseTensor>()
              ->mutable_data<float>({3 * dim_embed}, platform::CPUPlace());
      memcpy(qkv_out_scale_data,
             qkv_out_scales.data(),
             qkv_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "QKVOutScale", {matmul0_w->Name() + "_out_scale"});

      auto out_out_scale_data =
          out_out_scale_var->GetMutable<phi::DenseTensor>()
              ->mutable_data<float>({dim_embed}, platform::CPUPlace());
      memcpy(out_out_scale_data,
             out_out_scales.data(),
             out_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});

      auto ffn0_out_scale_data =
          ffn0_out_scale_var->GetMutable<phi::DenseTensor>()
              ->mutable_data<float>({4 * dim_embed}, platform::CPUPlace());
      memcpy(ffn0_out_scale_data,
             ffn0_out_scales.data(),
             ffn0_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});

      auto ffn1_out_scale_data =
          ffn1_out_scale_var->GetMutable<phi::DenseTensor>()
              ->mutable_data<float>({dim_embed}, platform::CPUPlace());
      memcpy(ffn1_out_scale_data,
             ffn1_out_scales.data(),
             ffn1_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
    }
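A note on the arithmetic in this block, as a minimal standalone sketch (the function names below are illustrative, not from the pass): under symmetric int8 quantization a float value is stored roughly as round(x * 127 / scale), so an int32 GEMM accumulator is mapped back to float by multiplying the weight and activation scales, each divided by 127; the pass then broadcasts that product across the projection width (for example 3 * dim_embed entries for the fused QKV weight).

#include <vector>

// Illustrative sketch of the out-scale computation used above, assuming
// symmetric int8 quantization: float ~= int8 * (scale / 127).
float DequantScale(float weight_scale, float in_scale) {
  return (weight_scale / 127.0f) * (in_scale / 127.0f);
}

std::vector<float> BuildOutScales(int width, float weight_scale,
                                  float in_scale) {
  return std::vector<float>(width, DequantScale(weight_scale, in_scale));
}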
    auto* fused_multi_transformer =
        graph->CreateOpNode(&fused_multi_transformer_op_desc);
    if (enable_int8) {
      auto qkv_out_scale_node =
          CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
      auto out_out_scale_node = CreatePersistableVarNode(
          graph, matmul_linear_w->Name() + "_out_scale");
      auto ffn0_out_scale_node = CreatePersistableVarNode(
          graph, ffn_matmul0_w->Name() + "_out_scale");
      auto ffn1_out_scale_node = CreatePersistableVarNode(
          graph, ffn_matmul1_w->Name() + "_out_scale");

      IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
      IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
      IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
      IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
    }
    IR_NODE_LINK_TO(input0, fused_multi_transformer);
    IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
    IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
...
@@ -2477,291 +3589,374 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
    IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
    IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
    IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
    IR_NODE_LINK_TO(fused_multi_transformer, ffn_layer_norm_out);
    // rewrite while OP input
    // 1. delete k, v
    // 2. delete matmul1/2_w eltadd1/2_w
    // 3. add cache_kv
    auto while_Xs = while0->Op()->Input("X");
    while_Xs.erase(std::remove(std::begin(while_Xs),
                               std::end(while_Xs),
                               split0_k_out->Name()),
                   std::end(while_Xs));
    while_Xs.erase(std::remove(std::begin(while_Xs),
                               std::end(while_Xs),
                               split0_v_out->Name()),
                   std::end(while_Xs));
    while_Xs.emplace_back(cache_kv->Name());
    while0->Op()->SetInput("X", while_Xs);

    // rewrite while OP output
    // 1. delete k, v
    // 2. add cache_kv
    auto while_Outs = while0->Op()->Output("Out");
    while_Outs.erase(std::remove(std::begin(while_Outs),
                                 std::end(while_Outs),
                                 split0_k_out->Name()),
                     std::end(while_Outs));
    while_Outs.erase(std::remove(std::begin(while_Outs),
                                 std::end(while_Outs),
                                 split0_v_out->Name()),
                     std::end(while_Outs));
    while_Outs.emplace_back(cache_kv->Name());
    while0->Op()->SetOutput("Out", while_Outs);

    // link CacheKV to while
    IR_NODE_LINK_TO(cache_kv, while0)
    // unlink origin KV output to while
    IR_NODE_UNLINK(split0_k_out, while0);
    IR_NODE_UNLINK(split0_v_out, while0);
    IR_NODE_UNLINK(while0, split0_k_out);
    IR_NODE_UNLINK(while0, split0_v_out);
  };
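The list surgery above is the standard C++ erase-remove idiom; a standalone sketch of the same rewrite on a plain vector of names:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Stand-ins for the while op's "X" input list in the old, decoder-linked graph.
  std::vector<std::string> while_xs = {"split0_k_out", "split0_v_out", "mask"};
  // std::remove shifts the surviving names forward; erase trims the tail.
  while_xs.erase(std::remove(while_xs.begin(), while_xs.end(),
                             std::string("split0_k_out")),
                 while_xs.end());
  while_xs.emplace_back("cache_kv");  // the fused op's cache variable
  for (const auto& name : while_xs) std::cout << name << "\n";
  return 0;
}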
  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv "
                      "pass in op compat failed.";
      LOG(WARNING) << "fused_multi_transformer_encoder pass in "
                      "op compat failed.";
      return;
    }
    VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse";
    VLOG(4) << "handle MultiTransformer encoder fuse";
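For orientation before the long run of bindings below: GET_IR_NODE_FROM_SUBGRAPH(var, name, pattern) binds the concrete graph node that the detector matched for pattern node `name` to a local `Node* var`. A hypothetical, simplified rendering of what each invocation amounts to (the real macro also verifies that the match exists):

// Hypothetical simplification, for intuition only; not the real macro body.
#define GET_IR_NODE_FROM_SUBGRAPH_SKETCH(var, name, pat) \
  ir::Node* var = subgraph.at(pat.name##_n());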
    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(c_identity0, c_identity0, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(c_identity0_out, c_identity0_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(c_identity1, c_identity1, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(c_identity1_out, c_identity1_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(c_identity2, c_identity2, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(c_identity2_out, c_identity2_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity, ffn_c_identity, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity_out, ffn_c_identity_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul0, matmul0, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul0_out, matmul0_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul0_w, matmul0_w, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(split0, split0, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul1, matmul1, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul1_out, matmul1_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul1_w, matmul1_w, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_q, scale_q, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_q_out, scale_q_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul2, matmul2, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul2_out, matmul2_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul2_w, matmul2_w, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(attention_output, attention_output, multi_devices_fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, ffn_layer_norm, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, ffn_layer_norm_scale, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, ffn_layer_norm_scale, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, ffn_layer_norm_bias, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, ffn_layer_norm_bias, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, ffn_layer_norm_mean, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, ffn_layer_norm_mean, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, ffn_layer_norm_variance, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, ffn_layer_norm_variance, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, ffn_layer_norm_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, ffn_layer_norm_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0, ffn_matmul0, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, ffn_matmul0_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_w, ffn_matmul0_w, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0, ffn_eltadd0, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_b, ffn_eltadd0_b, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, ffn_eltadd0_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_act, ffn_act, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_act_out, ffn_act_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1, ffn_matmul1, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, ffn_matmul1_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_w, ffn_matmul1_w, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum, ffn_c_allreduce_sum, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum_out, ffn_c_allreduce_sum_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1, ffn_eltadd1, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_b, ffn_eltadd1_b, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, ffn_eltadd1_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, ffn_eltadd_out, multi_devices_fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(ffn_output, ffn_output, multi_devices_fused_multi_transformer_pattern)
    // nodes need be removed
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear, matmul_linear, multi_devices_fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, matmul_linear_w, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, matmul_linear_w, multi_devices_fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, matmul_linear_out, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, matmul_linear_out, multi_devices_fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum, c_allreduce_sum, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum_out, c_allreduce_sum_out, multi_devices_fused_multi_transformer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear, eltadd_linear, multi_devices_fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, eltadd_linear_b, multi_devices_fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, multi_devices_fused_multi_transformer_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(while0, while0, fused_multi_transformer_fuse_qkv_pattern)
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, multi_devices_fused_multi_transformer_pattern)
    fuse_creater(input0,
                 layer_norm,
                 layer_norm_scale,
                 layer_norm_bias,
                 layer_norm_mean,
                 layer_norm_variance,
                 matmul0_w,
                 eltadd0_b,
                 split0_k_out,
                 split0_v_out,
                 eltadd_qk_b,
                 reshape2_0,
                 matmul_linear_w,
                 eltadd_linear_b,
                 while0,
                 ffn_layer_norm,
                 ffn_layer_norm_scale,
                 ffn_layer_norm_bias,
                 ffn_layer_norm_mean,
                 ffn_layer_norm_variance,
                 ffn_matmul0_w,
                 ffn_matmul1_w,
                 ffn_eltadd0_b,
                 ffn_eltadd1_b,
                 ffn_output);
    fuse_creater(input0,
                 c_identity0,
                 layer_norm,
                 layer_norm_scale,
                 layer_norm_bias,
                 layer_norm_mean,
                 layer_norm_variance,
                 matmul0,
                 matmul0_w,
                 matmul1_w,
                 matmul2_w,
                 eltadd0_b,
                 eltadd1_b,
                 eltadd2_b,
                 transpose2_1_out,
                 transpose2_2_out,
                 eltadd_qk_b,
                 reshape2_0,
                 matmul_linear,
                 matmul_linear_w,
                 eltadd_linear_b,
                 ffn_layer_norm,
                 ffn_layer_norm_scale,
                 ffn_layer_norm_bias,
                 ffn_layer_norm_mean,
                 ffn_layer_norm_variance,
                 ffn_matmul0,
                 ffn_matmul0_w,
                 ffn_matmul1,
                 ffn_matmul1_w,
                 ffn_eltadd0_b,
                 ffn_eltadd1_b,
                 ffn_act,
                 ffn_layer_norm_out);
    std::unordered_set<const Node*> marked_nodes({layer_norm,
                                                  layer_norm_mean,
                                                  layer_norm_variance,
                                                  layer_norm_out,
                                                  matmul0,
                                                  matmul0_out,
                                                  eltadd0,
                                                  eltadd0_out,
                                                  reshape2_0,
                                                  reshape2_0_out,
                                                  transpose2_0,
                                                  transpose2_0_out,
                                                  split0,
                                                  split0_q_out,
                                                  split0_k_out,
                                                  split0_v_out,
                                                  matmul_qk,
                                                  matmul_qk_out,
                                                  scale_qk,
                                                  scale_qk_out,
                                                  eltadd_qk,
                                                  eltadd_qk_out,
                                                  softmax_qk,
    std::unordered_set<const Node*> marked_nodes({c_identity0,
                                                  c_identity0_out,
                                                  c_identity1,
                                                  c_identity1_out,
                                                  c_identity2,
                                                  c_identity2_out,
                                                  layer_norm,
                                                  layer_norm_mean,
                                                  layer_norm_variance,
                                                  layer_norm_out,
                                                  matmul0,
                                                  matmul1,
                                                  matmul2,
                                                  matmul0_out,
                                                  matmul1_out,
                                                  matmul2_out,
                                                  eltadd0,
                                                  eltadd1,
                                                  eltadd2,
                                                  eltadd0_out,
                                                  eltadd1_out,
                                                  eltadd2_out,
                                                  reshape2_0,
                                                  reshape2_1,
                                                  reshape2_2,
                                                  reshape2_0_out,
                                                  reshape2_1_out,
                                                  reshape2_2_out,
                                                  transpose2_0,
                                                  transpose2_1,
                                                  transpose2_2,
                                                  transpose2_0_out,
                                                  transpose2_1_out,
                                                  transpose2_2_out,
                                                  scale_q,
                                                  scale_q_out,
                                                  matmul_qk,
                                                  matmul_qk_out,
                                                  eltadd_qk,
                                                  eltadd_qk_out,
                                                  softmax_qk,
...
@@ -2775,23 +3970,29 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                                                  transpose2_qkv_out,
                                                  matmul_linear,
                                                  matmul_linear_out,
                                                  c_allreduce_sum,
                                                  c_allreduce_sum_out,
                                                  eltadd_linear,
                                                  eltadd_linear_out,
                                                  eltadd_out,
                                                  ffn_layer_norm,
                                                  ffn_layer_norm_mean,
                                                  ffn_layer_norm_variance,
                                                  ffn_layer_norm_out,
                                                  ffn_c_identity,
                                                  ffn_c_identity_out,
                                                  ffn_matmul0,
                                                  ffn_matmul1,
                                                  ffn_matmul0_out,
                                                  ffn_matmul1_out,
                                                  ffn_c_allreduce_sum,
                                                  ffn_c_allreduce_sum_out,
                                                  ffn_eltadd0,
                                                  ffn_eltadd1,
                                                  ffn_eltadd0_out,
                                                  ffn_eltadd1_out,
                                                  ffn_gelu,
                                                  ffn_act,
                                                  ffn_gelu_out,
                                                  ffn_act_out,
                                                  ffn_output,
                                                  ffn_eltadd_out});
    // Remove unneeded nodes.
...
@@ -2803,25 +4004,25 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
  return fusion_count;
}
void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const {
void MultiDevicesFusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal("During the fused_multi_transformer_encoder pass, "
                              "The scope should not be null."));
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal(
          "During the multi_transformer pass, The scope should not be null."));

  int fusion_count = BuildFusion(graph, name_scope_, scope);
  if (fusion_count > 0) {
    graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true));
    graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
    graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count));
  }
  AddStatis(fusion_count);
}
FusedMultiTransformerEncoderFuseQKVPass::FusedMultiTransformerEncoderFuseQKVPass() {
MultiDevicesFusedMultiTransformerEncoderPass::MultiDevicesFusedMultiTransformerEncoderPass() {
  AddOpCompat(OpCompat("layer_norm"))
      .AddInput("X")
      .IsTensor()
...
@@ -2849,23 +4050,6 @@ FusedMultiTransformerEncoderFuseQKVPass::
      .IsNumGT(0)
      .End();

  AddOpCompat(OpCompat("scale"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("scale")
      .IsType<float>()  // copy to new op. so unconstrained.
      .End()
      .AddAttr("bias")
      .IsNumEQ(0.f)
      .End()
      .AddAttr("bias_after_scale")  // bias is 0, so unconstrained.
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("matmul_v2"))
      .AddInput("X")  // the shape should be (B, S, N*H)
      .IsTensor()
...
@@ -2935,24 +4119,20 @@ FusedMultiTransformerEncoderFuseQKVPass::
      .IsType<std::vector<int>>()
      .End();
  AddOpCompat(OpCompat("matmul"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("alpha")
      .IsNumGE(0.0f)
      .IsNumLE(1.0f)
      .End()
      .AddAttr("transpose_X")
      .IsBoolEQ(false)
      .End()
      .AddAttr("transpose_Y")
      .IsType<bool>()
      .End();
  AddOpCompat(OpCompat("scale"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("scale")
      .IsType<float>()  // copy to new op. so unconstrained.
      .End()
      .AddAttr("bias")
      .IsNumEQ(0.f)
      .End()
      .AddAttr("bias_after_scale")  // bias is 0, so unconstrained.
      .IsType<bool>()
      .End();
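The swap of the `matmul` compat rule for a `scale` rule mirrors the pattern change: the old graph folded the 1/sqrt(head_dim) factor into matmul's `alpha`, while the new graph applies an explicit `scale` op to Q first. The two are algebraically identical, since alpha * (Q K^T) equals (alpha Q) K^T; a trivial standalone check (values chosen so the float arithmetic is exact):

#include <cassert>

int main() {
  const float q = 2.0f, k = 3.0f, alpha = 0.125f;  // 0.125 = 1/sqrt(64)
  // Scalars stand in for Q and K entries: scaling before or after the
  // product gives the same attention logits.
  assert((alpha * q) * k == alpha * (q * k));
  return 0;
}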
...
@@ -2978,18 +4158,12 @@ FusedMultiTransformerEncoderFuseQKVPass::
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("while"))
      .AddInput("X")  // A set of variables, unconstrained
      .End()
      .AddInput("Condition")  // A scalar
      .IsTensor()
      .End()
      .AddOutput("Out")  // A set of variables, unconstrained
      .End()
      .AddOutput("StepScopes")  // A vector of local scope, unconstrained
      .End()
      .AddAttr("sub_block")
      .IsType<framework::BlockDesc*>()
      .End();
  AddOpCompat(OpCompat("relu"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();
}
...
@@ -3040,6 +4214,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                      Node* ffn_matmul1_w,
                      Node* ffn_eltadd0_b,
                      Node* ffn_eltadd1_b,
                      Node* ffn_act,
                      Node* ffn_output) {
    auto* matmul0_op = matmul0->Op();
    auto* matmul_linear_op = matmul_linear->Op();
...
@@ -3163,6 +4338,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
    fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
    fused_multi_transformer_op_desc.SetAttr(
        "epsilon", layer_norm->Op()->GetAttr("epsilon"));
    fused_multi_transformer_op_desc.SetAttr("act_method",
                                            ffn_act->Op()->Type());

    // output dropout attribute
    fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
...
@@ -3175,6 +4352,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
    // Quantization attribute/Input
    if (enable_int8) {
      auto* dev_ctx = static_cast<phi::CPUContext*>(
          platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
      // Set input scale
      std::string matmul_input_scale_suffix = c_identity_op->Input("X")[0];
      auto qkv_in_scale = PADDLE_GET_CONST(
...
@@ -3240,36 +4419,44 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
      auto ffn1_out_scale_var =
          scope->Var(ffn_matmul1_w->Name() + "_out_scale");

      auto qkv_out_scale_data =
          qkv_out_scale_var->GetMutable<phi::DenseTensor>()
              ->mutable_data<float>({3 * dim_embed}, platform::CPUPlace());
      auto* qkv_out_scale_tensor =
          qkv_out_scale_var->GetMutable<phi::DenseTensor>();
      qkv_out_scale_tensor->Resize({3 * dim_embed});
      dev_ctx->Alloc<float>(qkv_out_scale_tensor);
      auto qkv_out_scale_data = qkv_out_scale_tensor->data<float>();
      memcpy(qkv_out_scale_data,
             qkv_out_scales.data(),
             qkv_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "QKVOutScale", {matmul0_w->Name() + "_out_scale"});

      auto out_out_scale_data =
          out_out_scale_var->GetMutable<phi::DenseTensor>()
              ->mutable_data<float>({dim_embed}, platform::CPUPlace());
      auto* out_out_scale_tensor =
          out_out_scale_var->GetMutable<phi::DenseTensor>();
      out_out_scale_tensor->Resize({dim_embed});
      dev_ctx->Alloc<float>(out_out_scale_tensor);
      auto out_out_scale_data = out_out_scale_tensor->data<float>();
      memcpy(out_out_scale_data,
             out_out_scales.data(),
             out_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});

      auto ffn0_out_scale_data =
          ffn0_out_scale_var->GetMutable<phi::DenseTensor>()
              ->mutable_data<float>({4 * dim_embed}, platform::CPUPlace());
      auto* ffn0_out_scale_tensor =
          ffn0_out_scale_var->GetMutable<phi::DenseTensor>();
      ffn0_out_scale_tensor->Resize({4 * dim_embed});
      dev_ctx->Alloc<float>(ffn0_out_scale_tensor);
      auto ffn0_out_scale_data = ffn0_out_scale_tensor->data<float>();
      memcpy(ffn0_out_scale_data,
             ffn0_out_scales.data(),
             ffn0_out_scales.size() * sizeof(float));
      fused_multi_transformer_op_desc.SetInput(
          "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});

      auto ffn1_out_scale_data =
          ffn1_out_scale_var->GetMutable<phi::DenseTensor>()
              ->mutable_data<float>({dim_embed}, platform::CPUPlace());
      auto* ffn1_out_scale_tensor =
          ffn1_out_scale_var->GetMutable<phi::DenseTensor>();
      ffn1_out_scale_tensor->Resize({dim_embed});
      dev_ctx->Alloc<float>(ffn1_out_scale_tensor);
      auto ffn1_out_scale_data = ffn1_out_scale_tensor->data<float>();
      memcpy(ffn1_out_scale_data,
             ffn1_out_scales.data(),
             ffn1_out_scales.size() * sizeof(float));
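This change replaces the older `mutable_data(place)` call with the phi-style flow: set the shape with Resize, allocate through the device context, then read back a typed pointer. All of these calls appear verbatim in the new code above; extracted for clarity, the idiom is (a framework-internal fragment, not a standalone program):

// Post-change allocation idiom for a persistable scale tensor (fragment).
auto* tensor = var->GetMutable<phi::DenseTensor>();  // var: a scope Variable*
tensor->Resize({dim_embed});          // 1. fix the shape metadata first
dev_ctx->Alloc<float>(tensor);        // 2. allocate via the phi::CPUContext
float* data = tensor->data<float>();  // 3. typed accessor replaces mutable_data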
...
@@ -3464,9 +4651,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
        fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
...
@@ -3603,6 +4790,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                 ffn_matmul1_w,
                 ffn_eltadd0_b,
                 ffn_eltadd1_b,
                 ffn_act,
                 ffn_output);

    std::unordered_set<const Node*> marked_nodes({layer_norm,
...
@@ -3661,8 +4849,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                                                  ffn_eltadd1,
                                                  ffn_eltadd0_out,
                                                  ffn_eltadd1_out,
                                                  ffn_gelu,
                                                  ffn_act,
                                                  ffn_gelu_out,
                                                  ffn_act_out,
                                                  ffn_eltadd_out});

    // Remove unneeded nodes.
...
@@ -3874,6 +5062,9 @@ REGISTER_PASS(fused_multi_transformer_encoder_pass,
              paddle::framework::ir::FusedMultiTransformerEncoderPass);
REGISTER_PASS(fused_multi_transformer_encoder_fuse_qkv_pass,
              paddle::framework::ir::FusedMultiTransformerEncoderFuseQKVPass);
REGISTER_PASS(multi_devices_fused_multi_transformer_encoder_pass,
              paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderPass);
REGISTER_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass,
              paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderFuseQKVPass);
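Once registered, the pass can be requested by name from the inference API. A hedged usage sketch (the model path is a placeholder, and whether the pass is already active by default depends on the paddle_pass_builder.cc change that is part of this same commit):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./model_dir");  // placeholder path
  // Append the new pass if it is not already in the active pass list.
  config.pass_builder()->AppendPass(
      "multi_devices_fused_multi_transformer_encoder_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}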
...
@@ -3898,6 +5089,16 @@ REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_fuse_qkv_pass)
            .LE("matmul", 1)
            .EQ("matmul_v2", 0)
            .EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(multi_devices_fused_multi_transformer_encoder_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("elementwise_add", 1)
            .EQ("reshape2", 0)
            .EQ("transpose2", 0)
            .EQ("scale", 0)
            .LE("matmul", 1)
            .EQ("matmul_v2", 0)
            .EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass)
    .AddCombination(
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h
...
@@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
  // Q, K, V path
  PATTERN_DECL_NODE(input0);
  PATTERN_DECL_NODE(layer_norm);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
  PATTERN_DECL_NODE(layer_norm_out);
  PATTERN_DECL_NODE(matmul0);
  PATTERN_DECL_NODE(matmul1);
  PATTERN_DECL_NODE(matmul2);
...
@@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
  PATTERN_DECL_NODE(transpose2_0_out);
  PATTERN_DECL_NODE(transpose2_1_out);
  PATTERN_DECL_NODE(transpose2_2_out);
  PATTERN_DECL_NODE(scale_q);
  PATTERN_DECL_NODE(scale_q_out);

  // Q, K matmul
  PATTERN_DECL_NODE(matmul_qk);
...
@@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
  PATTERN_DECL_NODE(eltadd_linear);
  PATTERN_DECL_NODE(eltadd_linear_b);
  PATTERN_DECL_NODE(eltadd_linear_out);
  PATTERN_DECL_NODE(dropout_linear);
  PATTERN_DECL_NODE(dropout_linear_out);

  // output elementwise_add
  PATTERN_DECL_NODE(eltadd_out)
  PATTERN_DECL_NODE(attention_output);

  // while loop
  PATTERN_DECL_NODE(while0);
  // post layer_norm
  PATTERN_DECL_NODE(layer_norm);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
  PATTERN_DECL_NODE(layer_norm_out);

  // Feed Forward nodes
  PATTERN_DECL_NODE(ffn_layer_norm);
  PATTERN_DECL_NODE(ffn_layer_norm_scale);
  PATTERN_DECL_NODE(ffn_layer_norm_bias);
  PATTERN_DECL_NODE(ffn_layer_norm_mean);
  PATTERN_DECL_NODE(ffn_layer_norm_variance);
  PATTERN_DECL_NODE(ffn_layer_norm_out);
  PATTERN_DECL_NODE(ffn_matmul0);
  PATTERN_DECL_NODE(ffn_matmul0_w);
  PATTERN_DECL_NODE(ffn_matmul0_out);
  PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd0_out);
  PATTERN_DECL_NODE(ffn_gelu);
  PATTERN_DECL_NODE(ffn_act);
  PATTERN_DECL_NODE(ffn_gelu_out);
  PATTERN_DECL_NODE(ffn_act_out);
  PATTERN_DECL_NODE(ffn_matmul1);
  PATTERN_DECL_NODE(ffn_matmul1_w);
  PATTERN_DECL_NODE(ffn_matmul1_out);
...
@@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
  // output elementwise_add
  PATTERN_DECL_NODE(ffn_eltadd_out)
  PATTERN_DECL_NODE(ffn_output);

  PATTERN_DECL_NODE(ffn_layer_norm);
  PATTERN_DECL_NODE(ffn_layer_norm_scale);
  PATTERN_DECL_NODE(ffn_layer_norm_bias);
  PATTERN_DECL_NODE(ffn_layer_norm_mean);
  PATTERN_DECL_NODE(ffn_layer_norm_variance);
  PATTERN_DECL_NODE(ffn_layer_norm_out);
};

struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
...
@@ -212,8 +216,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
  PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd0_out);
  PATTERN_DECL_NODE(ffn_gelu);
  PATTERN_DECL_NODE(ffn_act);
  PATTERN_DECL_NODE(ffn_gelu_out);
  PATTERN_DECL_NODE(ffn_act_out);
  PATTERN_DECL_NODE(ffn_matmul1);
  PATTERN_DECL_NODE(ffn_matmul1_w);
  PATTERN_DECL_NODE(ffn_matmul1_out);
...
@@ -226,6 +230,129 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
  PATTERN_DECL_NODE(ffn_output);
};

struct MultiDevicesFusedMultiTransformerEncoderPattern : public PatternBase {
  MultiDevicesFusedMultiTransformerEncoderPattern(PDPattern* pattern,
                                                  const std::string& name_scope)
      : PatternBase(pattern,
                    name_scope,
                    "multi_devices_fused_multi_transformer_encoder") {}

  PDNode* operator()();

  // Q, K, V path
  PATTERN_DECL_NODE(input0);
  PATTERN_DECL_NODE(c_identity0);
  PATTERN_DECL_NODE(c_identity0_out);
  PATTERN_DECL_NODE(c_identity1);
  PATTERN_DECL_NODE(c_identity1_out);
  PATTERN_DECL_NODE(c_identity2);
  PATTERN_DECL_NODE(c_identity2_out);
  PATTERN_DECL_NODE(matmul0);
  PATTERN_DECL_NODE(matmul1);
  PATTERN_DECL_NODE(matmul2);
  PATTERN_DECL_NODE(matmul0_w);
  PATTERN_DECL_NODE(matmul1_w);
  PATTERN_DECL_NODE(matmul2_w);
  PATTERN_DECL_NODE(matmul0_out);
  PATTERN_DECL_NODE(matmul1_out);
  PATTERN_DECL_NODE(matmul2_out);
  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_out);
  PATTERN_DECL_NODE(eltadd1_out);
  PATTERN_DECL_NODE(eltadd2_out);
  PATTERN_DECL_NODE(reshape2_0);
  PATTERN_DECL_NODE(reshape2_1);
  PATTERN_DECL_NODE(reshape2_2);
  PATTERN_DECL_NODE(reshape2_0_out);
  PATTERN_DECL_NODE(reshape2_1_out);
  PATTERN_DECL_NODE(reshape2_2_out);
  PATTERN_DECL_NODE(transpose2_0);
  PATTERN_DECL_NODE(transpose2_1);
  PATTERN_DECL_NODE(transpose2_2);
  PATTERN_DECL_NODE(transpose2_0_out);
  PATTERN_DECL_NODE(transpose2_1_out);
  PATTERN_DECL_NODE(transpose2_2_out);
  PATTERN_DECL_NODE(scale_q);
  PATTERN_DECL_NODE(scale_q_out);

  // Q, K matmul
  PATTERN_DECL_NODE(matmul_qk);
  PATTERN_DECL_NODE(matmul_qk_out);
  PATTERN_DECL_NODE(eltadd_qk);
  PATTERN_DECL_NODE(eltadd_qk_b);
  PATTERN_DECL_NODE(eltadd_qk_out);
  PATTERN_DECL_NODE(softmax_qk);
  PATTERN_DECL_NODE(softmax_qk_out);

  // QK, V matmul
  PATTERN_DECL_NODE(matmul_qkv);
  PATTERN_DECL_NODE(matmul_qkv_out);
  PATTERN_DECL_NODE(reshape2_qkv);
  PATTERN_DECL_NODE(reshape2_qkv_out);
  PATTERN_DECL_NODE(transpose2_qkv);
  PATTERN_DECL_NODE(transpose2_qkv_out);

  // out linear
  PATTERN_DECL_NODE(matmul_linear);
  PATTERN_DECL_NODE(matmul_linear_w);
  PATTERN_DECL_NODE(matmul_linear_out);
  PATTERN_DECL_NODE(c_allreduce_sum);
  PATTERN_DECL_NODE(c_allreduce_sum_out);
  PATTERN_DECL_NODE(eltadd_linear);
  PATTERN_DECL_NODE(eltadd_linear_b);
  PATTERN_DECL_NODE(eltadd_linear_out);
  PATTERN_DECL_NODE(dropout_linear);
  PATTERN_DECL_NODE(dropout_linear_out);

  // output elementwise_add
  PATTERN_DECL_NODE(eltadd_out)
  PATTERN_DECL_NODE(attention_output);

  // post layer_norm
  PATTERN_DECL_NODE(layer_norm);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
  PATTERN_DECL_NODE(layer_norm_out);

  // Feed Forward nodes
  PATTERN_DECL_NODE(ffn_c_identity);
  PATTERN_DECL_NODE(ffn_c_identity_out);
  PATTERN_DECL_NODE(ffn_matmul0);
  PATTERN_DECL_NODE(ffn_matmul0_w);
  PATTERN_DECL_NODE(ffn_matmul0_out);
  PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd0_out);
  PATTERN_DECL_NODE(ffn_act);
  PATTERN_DECL_NODE(ffn_act_out);
  PATTERN_DECL_NODE(ffn_matmul1);
  PATTERN_DECL_NODE(ffn_matmul1_w);
  PATTERN_DECL_NODE(ffn_matmul1_out);
  PATTERN_DECL_NODE(ffn_c_allreduce_sum);
  PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
  PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd1_out);

  // output elementwise_add
  PATTERN_DECL_NODE(ffn_eltadd_out)
  PATTERN_DECL_NODE(ffn_output);

  PATTERN_DECL_NODE(ffn_layer_norm);
  PATTERN_DECL_NODE(ffn_layer_norm_scale);
  PATTERN_DECL_NODE(ffn_layer_norm_bias);
  PATTERN_DECL_NODE(ffn_layer_norm_mean);
  PATTERN_DECL_NODE(ffn_layer_norm_variance);
  PATTERN_DECL_NODE(ffn_layer_norm_out);
};

struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
    : public PatternBase {
  MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern(
...
@@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
  PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd0_out);
  PATTERN_DECL_NODE(ffn_gelu);
  PATTERN_DECL_NODE(ffn_act);
  PATTERN_DECL_NODE(ffn_gelu_out);
  PATTERN_DECL_NODE(ffn_act_out);
  PATTERN_DECL_NODE(ffn_matmul1);
  PATTERN_DECL_NODE(ffn_matmul1_w);
  PATTERN_DECL_NODE(ffn_matmul1_out);
...
@@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
                  Scope* scope) const;
};

class MultiDevicesFusedMultiTransformerEncoderPass : public FusePassBase {
 public:
  MultiDevicesFusedMultiTransformerEncoderPass();
  virtual ~MultiDevicesFusedMultiTransformerEncoderPass() {}

 protected:
  void ApplyImpl(Graph* graph) const;
  const std::string name_scope_{"multi_devices_fused_multi_transformer_encoder"};

 private:
  int BuildFusion(Graph* graph,
                  const std::string& name_scope,
                  Scope* scope) const;
};

class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
    : public FusePassBase {
 public:
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc
...
@@ -56,8 +56,8 @@ Scope* CreateParamScope() {
  // FFN: fc1 -> (gelu) -> fc2
  AddVarToScope(param_scope, "ffn_weights0", {1024, 4096});
  AddVarToScope(param_scope, "ffn_weights1", {4096, 1024});
  AddVarToScope(param_scope, "ffn_bias_0", {4096});
  AddVarToScope(param_scope, "ffn_bias0", {4096});
  AddVarToScope(param_scope, "ffn_bias_1", {1024});
  AddVarToScope(param_scope, "ffn_bias1", {1024});
  return param_scope;
}
...
@@ -65,10 +65,9 @@ Scope* CreateParamScope() {
TEST(FusedMultiTransformerEncoderPass, basic) {
  // inputs                           operator           output
  // --------------------------------------------------------------------
  // (x, ln_scale, ln_bias)           layer_norm      -> layer_norm_out
  // (x, weights_0)                   matmul_v2       -> matmul_out0
  // (layer_norm_out, weights_0)      matmul_v2       -> matmul_out0
  // (x, weights_1)                   matmul_v2       -> matmul_out1
  // (layer_norm_out, weights_1)      matmul_v2       -> matmul_out1
  // (x, weights_2)                   matmul_v2       -> matmul_out2
  // (layer_norm_out, weights_2)      matmul_v2       -> matmul_out2
  // (matmul_out0, bias_0)            elementwise_add -> eltadd_0
  // (matmul_out1, bias_1)            elementwise_add -> eltadd_1
  // (matmul_out2, bias_2)            elementwise_add -> eltadd_2
...
@@ -78,7 +77,8 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
  // (reshape_0)                      transpose2      -> transpose_0
  // (reshape_1)                      transpose2      -> transpose_1
  // (reshape_2)                      transpose2      -> transpose_2
  // (transpose_0, transpose_1)       matmul          -> matmul_qk
  // (transpose_0)                    scale           -> scale_0
  // (scale_0, transpose_1)           matmul          -> matmul_qk
  // (matmul_qk, bias_qk)             elementwise_add -> eltadd_qk
  // (eltadd_qk)                      softmax         -> softmax_qk
  // (softmax_qk, transpose_2)        matmul_v2       -> matmul_qkv
...
@@ -86,35 +86,28 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   // (transpose_qkv)              reshape               -> reshape_qkv
   // (reshape_qkv)                matmul_v2             -> matmul_linear
   // (matmul_linear)              elementwise_add       -> eltadd_linear
-  // (eltadd_out)                 elementwise_add       -> attention_out
+  // (eltadd_linear)              elementwise_add       -> attention_out
   //
-  // (attention_out, scale, bias) layer_norm            -> ffn_layer_norm_out
+  // (attention_out, scale, bias) layer_norm            -> layer_norm_out
   // (layer_norm_out, ffn_matmul0_w) matmul_v2          -> ffn_matmul0
   // (ffn_matmul0, ffn_bias0)     elementwise_add       -> ffn_eltadd0
   // (ffn_eltadd0)                gelu                  -> ffn_gelu
   // (ffn_gelu)                   matmul_v2             -> ffn_matmul1
   // (ffn_matmul1, ffn_bias1)     elementwise_add       -> ffn_eltadd1
-  // (attention_out, ffn_eltadd1) elementwise_add       -> ffn_output
+  // (layer_norm_out, ffn_eltadd1) elementwise_add      -> ffn_output
   //
-  // (transpose_1, transpose_2)   while                 -> decoder block
+  // (ffn_output, scale, bias)    layer_norm            -> ffn_layer_norm_out
   Layers layers;
   // MHA: pre LayerNorm
   auto* x = layers.data("x", {1, 128, 1024});
-  auto* ln_scale = layers.data("ln_scale", {1024}, true);
-  auto* ln_bias = layers.data("ln_bias", {1024}, true);
-  auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];

   // MHA: QKV fc
   auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
   auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
   auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
-  auto* matmul_out_0 =
-      layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
-  auto* matmul_out_1 =
-      layers.matmul_v2(ln_out, weights_1, nullptr, false, true);
-  auto* matmul_out_2 =
-      layers.matmul_v2(ln_out, weights_2, nullptr, false, true);
+  auto* matmul_out_0 =
+      layers.matmul_v2(x, weights_0, nullptr, false, false);
+  auto* matmul_out_1 =
+      layers.matmul_v2(x, weights_1, nullptr, false, false);
+  auto* matmul_out_2 =
+      layers.matmul_v2(x, weights_2, nullptr, false, false);

   auto* b0 = layers.data("bias_0", {1024}, true);
   auto* b1 = layers.data("bias_1", {1024}, true);
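Side note on the three QKV projections above: they all consume the same input tensor, which is exactly the structure the companion *_fuse_qkv passes target, where the model instead carries one matmul against a column-wise concatenation [W_q | W_k | W_v] followed by a split. The two forms produce identical numbers; a tiny self-contained illustration of that identity (toy 2x2 weights, not the pass implementation):

#include <cassert>
#include <vector>

// y[j] = sum_i x[i] * w[i][j]   (1 x n row vector times n x m matrix)
std::vector<float> matmul(const std::vector<float>& x,
                          const std::vector<std::vector<float>>& w) {
  std::vector<float> y(w[0].size(), 0.0f);
  for (size_t i = 0; i < x.size(); ++i)
    for (size_t j = 0; j < y.size(); ++j) y[j] += x[i] * w[i][j];
  return y;
}

int main() {
  std::vector<float> x = {1.0f, 2.0f};
  std::vector<std::vector<float>> wq = {{1, 0}, {0, 1}};
  std::vector<std::vector<float>> wk = {{2, 0}, {0, 2}};
  std::vector<std::vector<float>> wv = {{3, 0}, {0, 3}};
  // Column-wise concatenation [Wq | Wk | Wv].
  std::vector<std::vector<float>> wqkv = {{1, 0, 2, 0, 3, 0},
                                          {0, 1, 0, 2, 0, 3}};
  auto q = matmul(x, wq), k = matmul(x, wk), v = matmul(x, wv);
  auto qkv = matmul(x, wqkv);  // one fused matmul, then "split" by offset
  for (int j = 0; j < 2; ++j) {
    assert(qkv[j] == q[j] && qkv[2 + j] == k[j] && qkv[4 + j] == v[j]);
  }
  return 0;
}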
...
@@ -136,14 +129,13 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
   auto* transpose_2 = layers.transpose2(reshape_2, axis, true);

-  // Link to decoder while block
-  layers.while_loop({transpose_1, transpose_2});
+  // q scale
+  auto* scale_q = layers.scale(transpose_0, 0.125, 0, false);

   // MHA: QK matmul
-  auto* matmul_qk =
-      layers.matmul(transpose_0, transpose_1, nullptr, false, true);
+  auto* matmul_qk =
+      layers.matmul_v2(scale_q, transpose_1, nullptr, false, true);

-  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
+  auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
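The bias_qk input shrinks here from {1, 12, 128, 128} to {1, 1, 1, 128}: a per-key padding mask that elementwise_add broadcasts across the head and query dimensions instead of a fully materialized per-head mask. A rough sketch of the broadcast semantics (plain loops, not Paddle's kernel):

#include <vector>

int main() {
  const int heads = 12, seq = 128;
  std::vector<float> scores(heads * seq * seq, 0.0f);  // [1, 12, 128, 128]
  std::vector<float> mask(seq, 0.0f);                  // [1, 1, 1, 128]
  mask[seq - 1] = -1e9f;  // e.g. mask out the last (padded) key position

  // elementwise_add with broadcasting: the mask is indexed by the key
  // axis only, so one row of 128 values covers every head and query.
  for (int h = 0; h < heads; ++h)
    for (int q = 0; q < seq; ++q)
      for (int k = 0; k < seq; ++k)
        scores[(h * seq + q) * seq + k] += mask[k];
  return 0;
}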
...
@@ -155,19 +147,18 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   // MHA: out Linear
   auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
-  auto* bias_l = layers.data("weightsl", {1024, 1024}, true);
+  auto* bias_l = layers.data("bias_l", {1024}, true);
   auto* linear_matmut_out =
-      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
+      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false);
   auto* linear_eltadd_out =
       layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
   auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);

-  // FFN: pre LayerNorm
-  auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
-  auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
-  auto* ffn_ln_out =
-      layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
+  // post LayerNorm
+  auto* ln_scale = layers.data("ln_scale", {1024}, true);
+  auto* ln_bias = layers.data("ln_bias", {1024}, true);
+  auto* ln_out =
+      layers.layer_norm(attention_out, ln_scale, ln_bias)[0];

   // FFN: fc1 -> gelu -> fc2
   auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
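This hunk carries the structural point of the test change: the old graph normalized before each sublayer (pre-LN, y = x + f(LN(x))), while the new graph normalizes after the residual add (post-LN, y = LN(x + f(x))), which is the layout the updated pass now matches. Schematically, with the sublayer f left abstract (a sketch under those assumptions, not the pass code):

#include <cmath>
#include <functional>
#include <vector>

using Vec = std::vector<float>;
using Sublayer = std::function<Vec(const Vec&)>;

Vec add(const Vec& a, const Vec& b) {
  Vec y(a.size());
  for (size_t i = 0; i < a.size(); ++i) y[i] = a[i] + b[i];
  return y;
}

// Plain mean/variance normalization; the real op also applies scale/bias.
Vec layer_norm(const Vec& x) {
  float mean = 0.f, var = 0.f;
  for (float v : x) mean += v;
  mean /= x.size();
  for (float v : x) var += (v - mean) * (v - mean);
  var /= x.size();
  Vec y(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    y[i] = (x[i] - mean) / std::sqrt(var + 1e-5f);
  return y;
}

// Old test graph: normalize first, then residual-add (pre-LN).
Vec pre_ln_block(const Vec& x, const Sublayer& f) {
  return add(x, f(layer_norm(x)));
}

// New test graph: residual-add first, then normalize (post-LN).
Vec post_ln_block(const Vec& x, const Sublayer& f) {
  return layer_norm(add(x, f(x)));
}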
...
@@ -175,7 +166,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
   auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
   auto* ffn_matmul0_out =
-      layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
+      layers.matmul_v2(ln_out, ffn_weights0, nullptr, false, true);
   auto* ffn_eltadd0_out =
       layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
   auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
...
@@ -184,7 +175,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   auto* ffn_eltadd1_out =
       layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
-  layers.elementwise_add(attention_out, ffn_eltadd1_out);
+  auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out);
+
+  // FFN: post LayerNorm
+  auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
+  auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
+  layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0];

   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   graph->Set("__param_scope__", CreateParamScope());
...
@@ -203,12 +199,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");

   PADDLE_ENFORCE_EQ(num_nodes_before,
-                    num_nodes_after + 56,
+                    num_nodes_after + 58,
                     platform::errors::InvalidArgument(
                         "After the fused_multi_transformer_encoder_pass, The "
                         "node num in graph "
                         "should be %d, but the result is %d",
-                        num_nodes_before - 56,
+                        num_nodes_before - 58,
                         num_nodes_after));
   PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                     1,
...
@@ -225,6 +221,183 @@ TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) {
       .IsPassCompatible("fused_multi_transformer_encoder_pass"));
 }
+
+TEST(MultiDevicesFusedMultiTransformerEncoderPass, basic) {
+  // inputs                              operator           output
+  // --------------------------------------------------------------------
+  // (x)                                 c_identity      -> c_identity0_out
+  // (x)                                 c_identity      -> c_identity1_out
+  // (x)                                 c_identity      -> c_identity2_out
+  // (c_identity0_out, weights_0)        matmul_v2       -> matmul_out0
+  // (c_identity1_out, weights_1)        matmul_v2       -> matmul_out1
+  // (c_identity2_out, weights_2)        matmul_v2       -> matmul_out2
+  // (matmul_out0, bias_0)               elementwise_add -> eltadd_0
+  // (matmul_out1, bias_1)               elementwise_add -> eltadd_1
+  // (matmul_out2, bias_2)               elementwise_add -> eltadd_2
+  // (eltadd_0)                          reshape2        -> reshape_0
+  // (eltadd_1)                          reshape2        -> reshape_1
+  // (eltadd_2)                          reshape2        -> reshape_2
+  // (reshape_0)                         transpose2      -> transpose_0
+  // (reshape_1)                         transpose2      -> transpose_1
+  // (reshape_2)                         transpose2      -> transpose_2
+  // (transpose_0)                       scale           -> scale_0
+  // (scale_0, transpose_1)              matmul          -> matmul_qk
+  // (matmul_qk, bias_qk)                elementwise_add -> eltadd_qk
+  // (eltadd_qk)                         softmax         -> softmax_qk
+  // (softmax_qk, transpose_2)           matmul_v2       -> matmul_qkv
+  // (matmul_qkv)                        transpose       -> transpose_qkv
+  // (transpose_qkv)                     reshape         -> reshape_qkv
+  // (reshape_qkv)                       matmul_v2       -> matmul_linear
+  // (matmul_linear)                     c_all_reduce    -> c_all_reduce_out
+  // (c_all_reduce_out)                  elementwise_add -> eltadd_linear
+  // (eltadd_linear)                     elementwise_add -> attention_out
+  //
+  // (attention_out, scale, bias)        layer_norm      -> layer_norm_out
+  // (layer_norm_out)                    c_identity      -> ffn_c_identity_out
+  // (ffn_c_identity_out, ffn_matmul0_w) matmul_v2       -> ffn_matmul0
+  // (ffn_matmul0, ffn_bias0)            elementwise_add -> ffn_eltadd0
+  // (ffn_eltadd0)                       gelu            -> ffn_gelu
+  // (ffn_gelu)                          matmul_v2       -> ffn_matmul1
+  // (ffn_matmul1)                       c_all_reduce    -> ffn_c_all_reduce_out
+  // (ffn_c_all_reduce_out, ffn_bias1)   elementwise_add -> ffn_eltadd1
+  // (layer_norm_out, ffn_eltadd1)       elementwise_add -> ffn_output
+  // (ffn_output, scale, bias)           layer_norm      -> ffn_layer_norm_out
+  Layers layers;
+  // MHA: pre LayerNorm
+  auto* x = layers.data("x", {1, 128, 1024});
+  auto* c_identity0_out = layers.c_identity(x);
+  auto* c_identity1_out = layers.c_identity(x);
+  auto* c_identity2_out = layers.c_identity(x);
+
+  // MHA: QKV fc
+  auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
+  auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
+  auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
+  auto* matmul_out_0 =
+      layers.matmul_v2(c_identity0_out, weights_0, nullptr, false, false);
+  auto* matmul_out_1 =
+      layers.matmul_v2(c_identity1_out, weights_1, nullptr, false, false);
+  auto* matmul_out_2 =
+      layers.matmul_v2(c_identity2_out, weights_2, nullptr, false, false);
+
+  auto* b0 = layers.data("bias_0", {1024}, true);
+  auto* b1 = layers.data("bias_1", {1024}, true);
+  auto* b2 = layers.data("bias_2", {1024}, true);
+  auto* elementwise_out_0 =
+      layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
+  auto* elementwise_out_1 =
+      layers.elementwise_add(matmul_out_1, b1, nullptr, 2);
+  auto* elementwise_out_2 =
+      layers.elementwise_add(matmul_out_2, b2, nullptr, 2);
+
+  std::vector<int> shape = {1, 128, 16, 64};
+  auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
+  auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
+  auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
+
+  std::vector<int> axis = {0, 2, 1, 3};
+  auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
+  auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
+  auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
+
+  // q scale
+  auto* scale_q = layers.scale(transpose_0, 0.125, 0, false);
+
+  // MHA: QK matmul
+  auto* matmul_qk =
+      layers.matmul_v2(scale_q, transpose_1, nullptr, false, true);
+
+  auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
+  auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
+  auto* softmax_qk = layers.softmax(elementwise_qk, -1);
+
+  // MHA: QKV matmul
+  auto* matmul_qkv = layers.matmul_v2(softmax_qk, transpose_2);
+  auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
+  auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
+
+  // MHA: out Linear
+  auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
+  auto* bias_l = layers.data("bias_l", {1024}, true);
+  auto* linear_matmut_out =
+      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false);
+  auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmut_out);
+  auto* linear_eltadd_out =
+      layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
+  auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
+
+  // post LayerNorm
+  auto* ln_scale = layers.data("ln_scale", {1024}, true);
+  auto* ln_bias = layers.data("ln_bias", {1024}, true);
+  auto* ln_out = layers.layer_norm(attention_out, ln_scale, ln_bias)[0];
+
+  auto* ffn_c_identity_out = layers.c_identity(ln_out);
+
+  // FFN: fc1 -> gelu -> fc2
+  auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
+  auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
+  auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
+  auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
+  auto* ffn_matmul0_out = layers.matmul_v2(
+      ffn_c_identity_out, ffn_weights0, nullptr, false, false);
+  auto* ffn_eltadd0_out =
+      layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
+  auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
+  auto* ffn_matmul1_out =
+      layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, false);
+  auto* ffn_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out);
+  auto* ffn_eltadd1_out =
+      layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2);
+  auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out);
+
+  // FFN: post LayerNorm
+  auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
+  auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
+  layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0];
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  graph->Set("__param_scope__", CreateParamScope());
+  graph->Set("enable_int8", new bool(false));
+
+  auto pass = PassRegistry::Instance().Get(
+      "multi_devices_fused_multi_transformer_encoder_pass");
+  if (pass.get() == nullptr)
+    LOG(INFO) << "get multi_devices_fused_multi_transformer_encoder_pass "
+                 "failed";
+
+  int num_nodes_before = graph->Nodes().size();
+  VLOG(3) << DebugString(graph);
+
+  graph.reset(pass->Apply(graph.release()));
+  int num_nodes_after = graph->Nodes().size();
+  VLOG(3) << DebugString(graph);
+  int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
+
+  PADDLE_ENFORCE_EQ(num_nodes_before,
+                    num_nodes_after + 70,
+                    platform::errors::InvalidArgument(
+                        "After the fused_multi_transformer_encoder_pass, The "
+                        "node num in graph "
+                        "should be %d, but the result is %d",
+                        num_nodes_before - 70,
+                        num_nodes_after));
+  PADDLE_ENFORCE_EQ(num_fused_nodes_after,
+                    1,
+                    platform::errors::InvalidArgument(
+                        "After the fused_multi_transformer_encoder pass, "
+                        "there should be one fused_multi_transformer op, "
+                        "but the result is %d",
+                        num_fused_nodes_after));
+}
+
+TEST(MultiDevicesFusedMultiTransformerEncoderPass, pass_op_version_check) {
+  ASSERT_TRUE(
+      paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
+          .IsPassCompatible(
+              "multi_devices_fused_multi_transformer_encoder_pass"));
+}
+
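On the collectives in the new multi-devices test: c_identity marks the replicated input each rank feeds into its shard of the column-parallel QKV/FFN weights, and c_allreduce_sum recombines the partial products after the row-parallel out-linear and fc2 matmuls. That a plain sum suffices can be checked on a single process by simulating two ranks (a sketch of the arithmetic; the real ops run collective communication such as NCCL):

#include <cassert>
#include <vector>

int main() {
  // x (1x4) times W (4x2), with W split row-wise across two "ranks":
  // rank 0 holds rows 0-1, rank 1 holds rows 2-3 (row-parallel linear).
  std::vector<float> x = {1, 2, 3, 4};
  std::vector<std::vector<float>> w = {{1, 0}, {0, 1}, {1, 1}, {2, 0}};

  std::vector<float> full(2, 0), part0(2, 0), part1(2, 0);
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 2; ++j) {
      full[j] += x[i] * w[i][j];              // single-device result
      (i < 2 ? part0 : part1)[j] += x[i] * w[i][j];  // per-rank partials
    }

  // c_allreduce_sum: every rank ends up with part0 + part1 == full.
  for (int j = 0; j < 2; ++j) assert(part0[j] + part1[j] == full[j]);
  return 0;
}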
 TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
   // inputs                              operator           output
   // --------------------------------------------------------------------
...
@@ -292,7 +465,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
   auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
-  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
+  auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
...
@@ -447,7 +620,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
   auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
-  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
+  auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
...
@@ -542,4 +715,5 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
 USE_PASS(fused_multi_transformer_encoder_pass);
 USE_PASS(fused_multi_transformer_encoder_fuse_qkv_pass);
+USE_PASS(multi_devices_fused_multi_transformer_encoder_pass);
 USE_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass);
paddle/fluid/inference/api/paddle_pass_builder.cc
...
@@ -179,6 +179,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fused_multi_transformer_decoder_pass"
,
"fused_multi_transformer_decoder_pass"
,
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_encoder_pass"
,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
"fuse_multi_transformer_layer_pass"
,
"fuse_multi_transformer_layer_pass"
,
...
@@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"fused_multi_transformer_decoder_pass"
,
//
"fused_multi_transformer_decoder_pass"
,
//
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"multi_devices_fused_multi_transformer_encoder_pass"
,
//
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"fuse_multi_transformer_layer_pass"
,
//
"fuse_multi_transformer_layer_pass"
,
//
...
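With these registrations the new pass ships in the default GPU strategies, so it normally needs no user action; overriding the defaults by name goes through the pass-builder API. A minimal sketch, assuming the usual inference header and with model path and GPU settings as placeholders:

#include "paddle_inference_api.h"  // header name as used in the inference demos

paddle_infer::Config MakeConfig() {
  paddle_infer::Config config;
  config.SetModel("./model_dir");  // placeholder model directory
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  // Only needed when overriding the defaults added by this commit:
  config.pass_builder()->DeletePass("fused_multi_transformer_encoder_pass");
  config.pass_builder()->AppendPass(
      "multi_devices_fused_multi_transformer_encoder_pass");
  return config;
}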