Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8d4450db
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8d4450db
编写于
12月 12, 2022
作者:
R
RichardWooSJTU
提交者:
GitHub
12月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update fused_multi_transformer_encoder_pass support GPT new matmul API (#48953)
* fit paddle.matmul in fleetx.gpt
上级
41d27818
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
193 addition
and
52 deletion
+193
-52
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
...luid/framework/ir/fused_multi_transformer_decoder_pass.cc
+74
-12
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h
...fluid/framework/ir/fused_multi_transformer_decoder_pass.h
+4
-0
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc
...amework/ir/fused_multi_transformer_decoder_pass_tester.cc
+17
-13
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
...luid/framework/ir/fused_multi_transformer_encoder_pass.cc
+76
-12
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h
...fluid/framework/ir/fused_multi_transformer_encoder_pass.h
+4
-0
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc
...amework/ir/fused_multi_transformer_encoder_pass_tester.cc
+18
-15
未找到文件。
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
浏览文件 @
8d4450db
...
...
@@ -478,7 +478,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
->
assert_is_op_input
(
"matmul
_v2
"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
...
...
@@ -496,7 +496,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto
*
concat_k_out_var
=
pattern
->
NewNode
(
concat_k_out_repr
())
->
assert_is_op_output
(
"concat"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
)
->
assert_is_op_input
(
"matmul
_v2
"
)
->
assert_is_op_input
(
"assign"
);
auto
*
concat_v_in_var
=
pattern
->
NewNode
(
concat_v_in_repr
())
...
...
@@ -529,10 +529,16 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
assign_v
->
LinksFrom
({
concat_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"scale"
);
auto
*
scale_qk
=
pattern
->
NewNode
(
scale_qk_repr
())
->
assert_is_op
(
"scale"
);
auto
*
scale_qk_out_var
=
pattern
->
NewNode
(
scale_qk_out_repr
())
->
assert_is_op_output
(
"scale"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
...
...
@@ -554,7 +560,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
// QK path Linsk
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
concat_k_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
scale_qk
->
LinksFrom
({
matmul_qk_out_var
}).
LinksTo
({
scale_qk_out_var
});
eltadd_qk
->
LinksFrom
({
scale_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
...
...
@@ -799,7 +806,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
->
assert_is_op_input
(
"matmul
_v2
"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
...
...
@@ -817,7 +824,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto
*
concat_k_out_var
=
pattern
->
NewNode
(
concat_k_out_repr
())
->
assert_is_op_output
(
"concat"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
)
->
assert_is_op_input
(
"matmul
_v2
"
)
->
assert_is_op_input
(
"assign"
);
auto
*
concat_v_in_var
=
pattern
->
NewNode
(
concat_v_in_repr
())
...
...
@@ -850,10 +857,16 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
assign_v
->
LinksFrom
({
concat_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"scale"
);
auto
*
scale_qk
=
pattern
->
NewNode
(
scale_qk_repr
())
->
assert_is_op
(
"scale"
);
auto
*
scale_qk_out_var
=
pattern
->
NewNode
(
scale_qk_out_repr
())
->
assert_is_op_output
(
"scale"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
...
...
@@ -875,7 +888,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
// QK path Linsk
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
concat_k_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
scale_qk
->
LinksFrom
({
matmul_qk_out_var
}).
LinksTo
({
scale_qk_out_var
});
eltadd_qk
->
LinksFrom
({
scale_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
...
...
@@ -2192,6 +2206,11 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk
,
scale_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk_out
,
scale_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
...
...
@@ -2296,6 +2315,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
assign_v
,
matmul_qk
,
matmul_qk_out
,
scale_qk
,
scale_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
...
...
@@ -2382,6 +2403,23 @@ FusedMultiTransformerDecoderFuseQKVPass::
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"scale"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"scale"
)
.
IsType
<
float
>
()
// copy to new op. so unconstrained.
.
End
()
.
AddAttr
(
"bias"
)
.
IsNumEQ
(
0.
f
)
.
End
()
.
AddAttr
(
"bias_after_scale"
)
// bias is 0, so unconstrained.
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
...
...
@@ -2917,6 +2955,11 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk
,
scale_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk_out
,
scale_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
...
...
@@ -3031,6 +3074,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
assign_v
,
matmul_qk
,
matmul_qk_out
,
scale_qk
,
scale_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
...
...
@@ -3124,6 +3169,23 @@ MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"scale"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"scale"
)
.
IsType
<
float
>
()
// copy to new op. so unconstrained.
.
End
()
.
AddAttr
(
"bias"
)
.
IsNumEQ
(
0.
f
)
.
End
()
.
AddAttr
(
"bias_after_scale"
)
// bias is 0, so unconstrained.
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
...
...
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h
浏览文件 @
8d4450db
...
...
@@ -182,6 +182,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
scale_qk
);
PATTERN_DECL_NODE
(
scale_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
...
...
@@ -282,6 +284,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
scale_qk
);
PATTERN_DECL_NODE
(
scale_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
...
...
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc
浏览文件 @
8d4450db
...
...
@@ -243,8 +243,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_v) concat -> concat_v
// (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk) scale -> scale_qk
// (scale_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
...
...
@@ -298,10 +299,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
layers
.
assign
(
concat_v
);
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
split_q
,
concat_k
,
nullptr
,
false
,
true
);
auto
*
matmul_qk
=
layers
.
matmul_v2
(
split_q
,
concat_k
,
nullptr
,
false
,
true
);
auto
*
scale_qk
=
layers
.
scale
(
matmul_qk
,
0.125
,
0
,
false
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul
_qk
,
bqk
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
scale
_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
// MHA: QKV matmul
...
...
@@ -361,11 +363,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
5
0
,
num_nodes_after
+
5
2
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d"
,
num_nodes_before
-
5
0
,
num_nodes_before
-
5
2
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
...
...
@@ -396,8 +398,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_v) concat -> concat_v
// (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk) scale -> scale_qk
// (scale_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
...
...
@@ -455,10 +458,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
layers
.
assign
(
concat_v
);
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
split_q
,
concat_k
,
nullptr
,
false
,
true
);
auto
*
matmul_qk
=
layers
.
matmul_v2
(
split_q
,
concat_k
,
nullptr
,
false
,
true
);
auto
*
scale_qk
=
layers
.
scale
(
matmul_qk
,
0.125
,
0
,
false
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul
_qk
,
bqk
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
scale
_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
// MHA: QKV matmul
...
...
@@ -523,11 +527,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
58
,
num_nodes_after
+
60
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d"
,
num_nodes_before
-
58
,
num_nodes_before
-
60
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
...
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
浏览文件 @
8d4450db
...
...
@@ -472,11 +472,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
->
assert_is_op_input
(
"matmul
_v2
"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul"
,
"Y"
)
->
assert_is_op_input
(
"matmul
_v2
"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
auto
*
split0_v_out_var
=
pattern
->
NewNode
(
split0_v_out_repr
())
->
assert_is_op_output
(
"split"
)
...
...
@@ -499,10 +499,17 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
while0
->
LinksFrom
({
split0_k_out_var
,
split0_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"scale"
);
auto
*
scale_qk
=
pattern
->
NewNode
(
scale_qk_repr
())
->
assert_is_op
(
"scale"
);
auto
*
scale_qk_out_var
=
pattern
->
NewNode
(
scale_qk_out_repr
())
->
assert_is_op_output
(
"scale"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
...
...
@@ -524,7 +531,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
// QK path Linsk
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
split0_k_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
scale_qk
->
LinksFrom
({
matmul_qk_out_var
}).
LinksTo
({
scale_qk_out_var
});
eltadd_qk
->
LinksFrom
({
scale_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
...
...
@@ -769,11 +777,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
->
assert_is_op_input
(
"matmul
_v2
"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul"
,
"Y"
)
->
assert_is_op_input
(
"matmul
_v2
"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
auto
*
split0_v_out_var
=
pattern
->
NewNode
(
split0_v_out_repr
())
->
assert_is_op_output
(
"split"
)
...
...
@@ -796,10 +804,17 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
while0
->
LinksFrom
({
split0_k_out_var
,
split0_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"scale"
);
auto
*
scale_qk
=
pattern
->
NewNode
(
scale_qk_repr
())
->
assert_is_op
(
"scale"
);
auto
*
scale_qk_out_var
=
pattern
->
NewNode
(
scale_qk_out_repr
())
->
assert_is_op_output
(
"scale"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
...
...
@@ -821,7 +836,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
// QK path Linsk
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
split0_k_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
scale_qk
->
LinksFrom
({
matmul_qk_out_var
}).
LinksTo
({
scale_qk_out_var
});
eltadd_qk
->
LinksFrom
({
scale_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
...
...
@@ -2637,6 +2653,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk
,
scale_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk_out
,
scale_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
...
...
@@ -2739,6 +2760,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
split0_v_out
,
matmul_qk
,
matmul_qk_out
,
scale_qk
,
scale_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
...
...
@@ -2826,6 +2849,23 @@ FusedMultiTransformerEncoderFuseQKVPass::
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"scale"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"scale"
)
.
IsType
<
float
>
()
// copy to new op. so unconstrained.
.
End
()
.
AddAttr
(
"bias"
)
.
IsNumEQ
(
0.
f
)
.
End
()
.
AddAttr
(
"bias_after_scale"
)
// bias is 0, so unconstrained.
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
...
...
@@ -3468,6 +3508,11 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk
,
scale_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_qk_out
,
scale_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
...
...
@@ -3580,6 +3625,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
split0_v_out
,
matmul_qk
,
matmul_qk_out
,
scale_qk
,
scale_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
...
...
@@ -3675,6 +3722,23 @@ MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"scale"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"scale"
)
.
IsType
<
float
>
()
// copy to new op. so unconstrained.
.
End
()
.
AddAttr
(
"bias"
)
.
IsNumEQ
(
0.
f
)
.
End
()
.
AddAttr
(
"bias_after_scale"
)
// bias is 0, so unconstrained.
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
...
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h
浏览文件 @
8d4450db
...
...
@@ -168,6 +168,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
scale_qk
);
PATTERN_DECL_NODE
(
scale_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
...
...
@@ -263,6 +265,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
scale_qk
);
PATTERN_DECL_NODE
(
scale_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
...
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc
浏览文件 @
8d4450db
...
...
@@ -236,9 +236,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// (transpose_0) split -> split_q, split_k,
// split_v (split_k) assign -> assign_k
// (split_v) assign -> assign_v
// (split_q, split_k) matmul
-> matmul_qk
// (matmul_qk
, bias_qk) elementwise_add -> eltadd
_qk
// (
eltadd_qk)
softmax -> softmax_qk
// (split_q, split_k) matmul
_v2
-> matmul_qk
// (matmul_qk
) scale -> scale
_qk
// (
scale_qk, eltadd_qk)
softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
...
...
@@ -289,10 +289,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
layers
.
while_loop
({
split_k
,
split_v
});
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
matmul_qk
=
layers
.
matmul_v2
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
scale_qk
=
layers
.
scale
(
matmul_qk
,
0.125
,
0
,
false
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul
_qk
,
bqk
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
scale
_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
// MHA: QKV matmul
...
...
@@ -352,11 +353,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
4
4
,
num_nodes_after
+
4
6
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d"
,
num_nodes_before
-
4
4
,
num_nodes_before
-
4
6
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
...
...
@@ -385,8 +386,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
// (transpose_0) split -> split_q, split_k,
// split_v (split_k) assign -> assign_k
// (split_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (split_q, split_k) matmul_v2 -> matmul_qk
// (matmul_qk) scale -> scale_qk
// (scale_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
...
...
@@ -442,10 +444,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
layers
.
while_loop
({
split_k
,
split_v
});
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
matmul_qk
=
layers
.
matmul_v2
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
scale_qk
=
layers
.
scale
(
matmul_qk
,
0.125
,
0
,
false
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul
_qk
,
bqk
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
scale
_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
// MHA: QKV matmul
...
...
@@ -510,11 +513,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
5
2
,
num_nodes_after
+
5
4
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d"
,
num_nodes_before
-
5
2
,
num_nodes_before
-
5
4
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录