BaiXuePrincess / Paddle — commit 9ad0e37e (forked from PaddlePaddle / Paddle)
Commit 9ad0e37e (unverified), authored on Nov 01, 2022 by Kaipeng Deng; committed via GitHub on Nov 01, 2022.
fix memory copy in prepare_data of FusedMultiTransformer pass (#47306)
* fix memory copy in prepare_data. test=develop
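In short, the change removes the dropout ops from the subgraphs these fusion passes match (after the attention softmax, after the output-projection bias add, and after the second FFN bias add), and it links the original weight/bias var nodes directly into the fused_multi_transformer op instead of marking them for deletion — apparently the extra copy in prepare_data that the commit title refers to. As a rough orientation before the per-file diffs, the simplified attention tail of the pattern now looks like the following condensed sketch, assembled from the new-side lines of this diff (illustrative only, not a verbatim excerpt of the pass):

  // Sketch: the softmax output now feeds matmul_v2 directly;
  // there is no dropout node between them in the matched pattern.
  auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
                                 ->assert_is_op_output("softmax")
                                 ->AsIntermediate()
                                 ->assert_is_op_input("matmul_v2", "X");
  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
  matmul_qkv->LinksFrom({softmax_qk_out_var, concat_0_out_var})
      .LinksTo({matmul_qkv_out_var});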
Parent: 8a1124b1
Showing 7 changed files with 154 additions and 580 deletions (+154 −580)
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc         +51  −227
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h          +0   −18
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc  +21  −45
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc         +60  −227
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h          +0   −18
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc  +21  −45
paddle/fluid/framework/ir/pass.cc                                         +1   −0
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
@@ -237,15 +237,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
                                  ->assert_is_op_output("softmax")
                                  ->AsIntermediate()
-                                 ->assert_is_op_input("dropout");
-  auto* dropout_qk =
-      pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
-  auto* dropout_qk_out_var = pattern->NewNode(dropout_qk_out_repr())
-                                 ->assert_is_op_output("dropout", "Out")
-                                 ->AsIntermediate()
-                                 ->assert_is_op_input("matmul_v2",
-                                                      "X");  // -> matmul_qkv
+                                 ->assert_is_op_input("matmul_v2",
+                                                      "X");  // -> matmul_qkv
 
   // QK path Linsk
   matmul_qk->LinksFrom({transpose2_0_out_var, concat_0_out_var})
@@ -253,7 +245,6 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
       .LinksTo({eltadd_qk_out_var});
   softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
-  dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
 
   // QKV path Nodes
   auto* matmul_qkv =
@@ -294,14 +285,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
                                     ->assert_is_op_output("elementwise_add")
                                     ->AsIntermediate()
-                                    ->assert_is_op_input("dropout");
-  auto* dropout_linear =
-      pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
-  auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
-                                     ->assert_is_op_output("dropout")
-                                     ->AsIntermediate()
-                                     ->assert_is_op_input("elementwise_add");
+                                    ->assert_is_op_input("elementwise_add");
 
   auto* eltadd_out =
       pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
@@ -310,7 +294,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
       ->AsIntermediate();
 
   // QKV path Links
-  matmul_qkv->LinksFrom({dropout_qk_out_var, concat_1_out_var})
+  matmul_qkv->LinksFrom({softmax_qk_out_var, concat_1_out_var})
       .LinksTo({matmul_qkv_out_var});
   transpose2_qkv->LinksFrom({matmul_qkv_out_var})
       .LinksTo({transpose2_qkv_out_var});
@@ -320,9 +304,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
       .LinksTo({matmul_linear_out_var});
   eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
       .LinksTo({eltadd_linear_out_var});
-  dropout_linear->LinksFrom({eltadd_linear_out_var})
-      .LinksTo({dropout_linear_out_var});
-  eltadd_out->LinksFrom({input0, dropout_linear_out_var})
+  eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
       .LinksTo({attention_output});
 
   // Feed Forward LayerNorm Nodes
@@ -358,7 +340,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
                            ffn_layer_norm_mean_var,
                            ffn_layer_norm_variance_var});
 
-  // Feed Forward fc1 -> gelu -> fc2 -> dropout
+  // Feed Forward fc1 -> gelu -> fc2
   auto* ffn_matmul0 =
       pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
   auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
@@ -403,13 +385,6 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("dropout");
-  auto* ffn_dropout =
-      pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
-  auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
-                                  ->assert_is_op_output("dropout")
-                                  ->AsIntermediate()
-                                  ->assert_is_op_input("elementwise_add");
+                                  ->assert_is_op_input("elementwise_add");
 
   auto* ffn_eltadd_out =
@@ -427,9 +402,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
-  ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
-  ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
+  ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
       .LinksTo({ffn_output});
 
   return ffn_output;
@@ -575,15 +549,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
                                  ->assert_is_op_output("softmax")
                                  ->AsIntermediate()
-                                 ->assert_is_op_input("dropout");
-  auto* dropout_qk =
-      pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
-  auto* dropout_qk_out_var = pattern->NewNode(dropout_qk_out_repr())
-                                 ->assert_is_op_output("dropout", "Out")
-                                 ->AsIntermediate()
-                                 ->assert_is_op_input("matmul_v2",
-                                                      "X");  // -> matmul_qkv
+                                 ->assert_is_op_input("matmul_v2",
+                                                      "X");  // -> matmul_qkv
 
   // QK path Linsk
   matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
@@ -591,7 +557,6 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
       .LinksTo({eltadd_qk_out_var});
   softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
-  dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
 
   // QKV path Nodes
   auto* matmul_qkv =
@@ -632,14 +597,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
                                     ->assert_is_op_output("elementwise_add")
                                     ->AsIntermediate()
-                                    ->assert_is_op_input("dropout");
-  auto* dropout_linear =
-      pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
-  auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
-                                     ->assert_is_op_output("dropout")
-                                     ->AsIntermediate()
-                                     ->assert_is_op_input("elementwise_add");
+                                    ->assert_is_op_input("elementwise_add");
 
   auto* eltadd_out =
       pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
@@ -648,7 +606,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       ->AsIntermediate();
 
   // QKV path Links
-  matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var})
+  matmul_qkv->LinksFrom({softmax_qk_out_var, concat_v_out_var})
       .LinksTo({matmul_qkv_out_var});
   transpose2_qkv->LinksFrom({matmul_qkv_out_var})
       .LinksTo({transpose2_qkv_out_var});
@@ -658,9 +616,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({matmul_linear_out_var});
   eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
       .LinksTo({eltadd_linear_out_var});
-  dropout_linear->LinksFrom({eltadd_linear_out_var})
-      .LinksTo({dropout_linear_out_var});
-  eltadd_out->LinksFrom({input0, dropout_linear_out_var})
+  eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
       .LinksTo({attention_output});
 
   // Feed Forward LayerNorm Nodes
@@ -696,7 +652,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
                            ffn_layer_norm_mean_var,
                            ffn_layer_norm_variance_var});
 
-  // Feed Forward fc1 -> gelu -> fc2 -> dropout
+  // Feed Forward fc1 -> gelu -> fc2
   auto* ffn_matmul0 =
       pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
   auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
@@ -741,13 +697,6 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("dropout");
-  auto* ffn_dropout =
-      pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
-  auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
-                                  ->assert_is_op_output("dropout")
-                                  ->AsIntermediate()
-                                  ->assert_is_op_input("elementwise_add");
+                                  ->assert_is_op_input("elementwise_add");
 
   auto* ffn_eltadd_out =
@@ -765,9 +714,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
-  ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
-  ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
+  ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
      .LinksTo({ffn_output});
 
   return ffn_output;
@@ -922,15 +870,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
                                  ->assert_is_op_output("softmax")
                                  ->AsIntermediate()
-                                 ->assert_is_op_input("dropout");
-  auto* dropout_qk =
-      pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
-  auto* dropout_qk_out_var = pattern->NewNode(dropout_qk_out_repr())
-                                 ->assert_is_op_output("dropout", "Out")
-                                 ->AsIntermediate()
-                                 ->assert_is_op_input("matmul_v2",
-                                                      "X");  // -> matmul_qkv
+                                 ->assert_is_op_input("matmul_v2",
+                                                      "X");  // -> matmul_qkv
 
   // QK path Linsk
   matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
@@ -938,7 +878,6 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
       .LinksTo({eltadd_qk_out_var});
   softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
-  dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
 
   // QKV path Nodes
   auto* matmul_qkv =
@@ -987,14 +926,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
                                     ->assert_is_op_output("elementwise_add")
                                     ->AsIntermediate()
-                                    ->assert_is_op_input("dropout");
-  auto* dropout_linear =
-      pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
-  auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
-                                     ->assert_is_op_output("dropout")
-                                     ->AsIntermediate()
-                                     ->assert_is_op_input("elementwise_add");
+                                    ->assert_is_op_input("elementwise_add");
 
   auto* eltadd_out =
       pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
@@ -1003,7 +935,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       ->AsIntermediate();
 
   // QKV path Links
-  matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var})
+  matmul_qkv->LinksFrom({softmax_qk_out_var, concat_v_out_var})
       .LinksTo({matmul_qkv_out_var});
   transpose2_qkv->LinksFrom({matmul_qkv_out_var})
       .LinksTo({transpose2_qkv_out_var});
@@ -1015,9 +947,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      .LinksTo({c_allreduce_sum_out_var});
   eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
       .LinksTo({eltadd_linear_out_var});
-  dropout_linear->LinksFrom({eltadd_linear_out_var})
-      .LinksTo({dropout_linear_out_var});
-  eltadd_out->LinksFrom({input0, dropout_linear_out_var})
+  eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
       .LinksTo({attention_output});
 
   // Feed Forward LayerNorm Nodes
@@ -1063,7 +993,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
       .LinksTo({ffn_c_identity_out_var});
 
-  // Feed Forward fc1 -> gelu -> fc2 -> dropout
+  // Feed Forward fc1 -> gelu -> fc2
   auto* ffn_matmul0 =
       pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
   auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
@@ -1117,13 +1047,6 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("dropout");
-  auto* ffn_dropout =
-      pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
-  auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
-                                  ->assert_is_op_output("dropout")
-                                  ->AsIntermediate()
-                                  ->assert_is_op_input("elementwise_add");
+                                  ->assert_is_op_input("elementwise_add");
 
   auto* ffn_eltadd_out =
@@ -1143,9 +1066,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      .LinksTo({ffn_c_allreduce_sum_out_var});
   ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
-  ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
-  ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
+  ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
       .LinksTo({ffn_output});
 
   return ffn_output;
@@ -1180,11 +1102,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                               Node* transpose2_1_out,
                               Node* transpose2_2_out,
                               Node* eltadd_qk_b,
-                              Node* dropout_qk,
                               Node* reshape2_0,
                               Node* matmul_linear_w,
                               Node* eltadd_linear_b,
-                              Node* dropout_linear,
                               Node* ffn_layer_norm,
                               Node* ffn_layer_norm_scale,
                               Node* ffn_layer_norm_bias,
@@ -1194,7 +1114,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                               Node* ffn_matmul1_w,
                               Node* ffn_eltadd0_b,
                               Node* ffn_eltadd1_b,
-                              Node* ffn_dropout,
                               Node* ffn_output) {
     // Calc index of transformer layer by LayerNorm Scale name
     // This calculation assumes:
@@ -1287,14 +1206,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
 
     // output dropout attribute
-    auto* dropout_op = dropout_linear->Op();
-    fused_multi_transformer_op_desc.SetAttr(
-        "dropout_rate", dropout_op->GetAttr("dropout_prob"));
-    fused_multi_transformer_op_desc.SetAttr("is_test",
-                                            dropout_op->GetAttr("is_test"));
-    fused_multi_transformer_op_desc.SetAttr(
-        "dropout_implementation",
-        dropout_op->GetAttr("dropout_implementation"));
+    fused_multi_transformer_op_desc.SetAttr("is_test", true);
+    fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
 
     auto* fused_multi_transformer =
         graph->CreateOpNode(&fused_multi_transformer_op_desc);
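A side effect of dropping dropout from the matched subgraph, visible in the hunk above (and repeated for the fuse-QKV and multi-device variants below), is that the fused op's dropout attributes can no longer be copied from a matched dropout op; the pass now fixes them. A condensed view, mirroring the added lines above — the "inference-time only" reading is an interpretation, not something the diff states:

  // Dropout is disabled on the emitted fused op; the former dropout_prob /
  // dropout_implementation attributes are no longer carried over from the graph.
  fused_multi_transformer_op_desc.SetAttr("is_test", true);
  fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);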
@@ -1313,6 +1226,15 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
     IR_NODE_LINK_TO(slice_op, slice_out);
     IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
 
+    IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
+    IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
+
     IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
   };
@@ -1451,11 +1373,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(
        ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_dropout, ffn_dropout, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(
        ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(
@@ -1499,10 +1416,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
        softmax_qk, softmax_qk, fused_multi_transformer_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
        softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        dropout_qk, dropout_qk, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(
-        dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(
        matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
@@ -1531,10 +1444,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
        eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(
        eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(
-        dropout_linear, dropout_linear, fused_multi_transformer_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(
-        dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(
        eltadd_out, eltadd_out, fused_multi_transformer_pattern)
@@ -1554,11 +1463,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                  transpose2_1_out,
                  transpose2_2_out,
                  eltadd_qk_b,
-                 dropout_qk,
                  reshape2_0,
                  matmul_linear_w,
                  eltadd_linear_b,
-                 dropout_linear,
                  ffn_layer_norm,
                  ffn_layer_norm_scale,
                  ffn_layer_norm_bias,
@@ -1568,12 +1475,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
-                 ffn_dropout,
                  ffn_output);
 
    std::unordered_set<const Node*> marked_nodes({layer_norm,
-                                                 layer_norm_scale,
-                                                 layer_norm_bias,
                                                  layer_norm_mean,
                                                  layer_norm_variance,
                                                  layer_norm_out,
@@ -1613,8 +1517,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                                                  eltadd_qk_out,
                                                  softmax_qk,
                                                  softmax_qk_out,
-                                                 dropout_qk,
-                                                 dropout_qk_out,
                                                  transpose2_qkv,
                                                  transpose2_qkv_out,
                                                  matmul_qkv,
@@ -1623,17 +1525,11 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                                                  transpose2_qkv,
                                                  transpose2_qkv_out,
                                                  matmul_linear,
-                                                 matmul_linear_w,
                                                  matmul_linear_out,
                                                  eltadd_linear,
-                                                 eltadd_linear_b,
                                                  eltadd_linear_out,
-                                                 dropout_linear,
-                                                 dropout_linear_out,
                                                  eltadd_out,
                                                  ffn_layer_norm,
-                                                 ffn_layer_norm_scale,
-                                                 ffn_layer_norm_bias,
                                                  ffn_layer_norm_mean,
                                                  ffn_layer_norm_variance,
                                                  ffn_layer_norm_out,
@@ -1647,8 +1543,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                                                  ffn_eltadd1_out,
                                                  ffn_gelu,
                                                  ffn_gelu_out,
-                                                 ffn_dropout,
-                                                 ffn_dropout_out,
                                                  ffn_eltadd_out});
 
   // Remove unneeded nodes.
@@ -1850,11 +1744,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                               Node* matmul0_w,
                               Node* eltadd0_b,
                               Node* eltadd_qk_b,
-                              Node* dropout_qk,
                               Node* reshape2_0,
                               Node* matmul_linear_w,
                               Node* eltadd_linear_b,
-                              Node* dropout_linear,
                               Node* ffn_layer_norm,
                               Node* ffn_layer_norm_scale,
                               Node* ffn_layer_norm_bias,
@@ -1864,7 +1756,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                               Node* ffn_matmul1_w,
                               Node* ffn_eltadd0_b,
                               Node* ffn_eltadd1_b,
-                              Node* ffn_dropout,
                               Node* ffn_output) {
     // Calc index of transformer layer by LayerNorm Scale name
     // This calculation assumes:
@@ -1957,17 +1848,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
 
     // output dropout attribute
-    auto* dropout_op = dropout_linear->Op();
-    fused_multi_transformer_op_desc.SetAttr(
-        "dropout_rate", dropout_op->GetAttr("dropout_prob"));
-    fused_multi_transformer_op_desc.SetAttr("is_test",
-                                            dropout_op->GetAttr("is_test"));
-    fused_multi_transformer_op_desc.SetAttr(
-        "dropout_implementation",
-        dropout_op->GetAttr("dropout_implementation"));
-    // fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
-    // fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
+    fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
+    fused_multi_transformer_op_desc.SetAttr("is_test", true);
 
     auto* fused_multi_transformer =
         graph->CreateOpNode(&fused_multi_transformer_op_desc);
@@ -1986,6 +1868,15 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     IR_NODE_LINK_TO(slice_op, slice_out);
     IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
 
+    IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
+    IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
+
     IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
   };
@@ -2116,12 +2007,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
                               ffn_eltadd1_out,
                               fused_multi_transformer_fuse_qkv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout,
-                              ffn_dropout,
-                              fused_multi_transformer_fuse_qkv_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
-                              ffn_dropout_out,
-                              fused_multi_transformer_fuse_qkv_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
                               ffn_eltadd_out,
                               fused_multi_transformer_fuse_qkv_pattern)
@@ -2153,11 +2038,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
                               softmax_qk_out,
                               fused_multi_transformer_fuse_qkv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(
-        dropout_qk_out, dropout_qk_out, fused_multi_transformer_fuse_qkv_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(
         matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
@@ -2193,12 +2073,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
                               eltadd_linear_out,
                               fused_multi_transformer_fuse_qkv_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
-                              dropout_linear,
-                              fused_multi_transformer_fuse_qkv_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
-                              dropout_linear_out,
-                              fused_multi_transformer_fuse_qkv_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(
         eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
@@ -2212,11 +2086,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  matmul0_w,
                  eltadd0_b,
                  eltadd_qk_b,
-                 dropout_qk,
                  reshape2_0,
                  matmul_linear_w,
                  eltadd_linear_b,
-                 dropout_linear,
                  ffn_layer_norm,
                  ffn_layer_norm_scale,
                  ffn_layer_norm_bias,
@@ -2226,12 +2098,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
-                 ffn_dropout,
                  ffn_output);
 
    std::unordered_set<const Node*> marked_nodes({layer_norm,
-                                                 layer_norm_scale,
-                                                 layer_norm_bias,
                                                  layer_norm_mean,
                                                  layer_norm_variance,
                                                  layer_norm_out,
@@ -2261,8 +2130,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                  eltadd_qk_out,
                                                  softmax_qk,
                                                  softmax_qk_out,
-                                                 dropout_qk,
-                                                 dropout_qk_out,
                                                  transpose2_qkv,
                                                  transpose2_qkv_out,
                                                  matmul_qkv,
@@ -2271,17 +2138,11 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                  transpose2_qkv,
                                                  transpose2_qkv_out,
                                                  matmul_linear,
-                                                 matmul_linear_w,
                                                  matmul_linear_out,
                                                  eltadd_linear,
-                                                 eltadd_linear_b,
                                                  eltadd_linear_out,
-                                                 dropout_linear,
-                                                 dropout_linear_out,
                                                  eltadd_out,
                                                  ffn_layer_norm,
-                                                 ffn_layer_norm_scale,
-                                                 ffn_layer_norm_bias,
                                                  ffn_layer_norm_mean,
                                                  ffn_layer_norm_variance,
                                                  ffn_layer_norm_out,
@@ -2295,8 +2156,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                  ffn_eltadd1_out,
                                                  ffn_gelu,
                                                  ffn_gelu_out,
-                                                 ffn_dropout,
-                                                 ffn_dropout_out,
                                                  ffn_eltadd_out});
 
   // Remove unneeded nodes.
@@ -2500,11 +2359,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                               Node* matmul0_w,
                               Node* eltadd0_b,
                               Node* eltadd_qk_b,
-                              Node* dropout_qk,
                               Node* reshape2_0,
                               Node* matmul_linear_w,
                               Node* eltadd_linear_b,
-                              Node* dropout_linear,
                               Node* ffn_layer_norm,
                               Node* ffn_layer_norm_scale,
                               Node* ffn_layer_norm_bias,
@@ -2514,7 +2371,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                               Node* ffn_matmul1_w,
                               Node* ffn_eltadd0_b,
                               Node* ffn_eltadd1_b,
-                              Node* ffn_dropout,
                               Node* ffn_output) {
     // Calc index of transformer layer by LayerNorm Scale name
     // This calculation assumes:
@@ -2607,23 +2463,14 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
 
     // output dropout attribute
-    auto* dropout_op = dropout_linear->Op();
-    fused_multi_transformer_op_desc.SetAttr(
-        "dropout_rate", dropout_op->GetAttr("dropout_prob"));
-    fused_multi_transformer_op_desc.SetAttr("is_test",
-                                            dropout_op->GetAttr("is_test"));
-    fused_multi_transformer_op_desc.SetAttr(
-        "dropout_implementation",
-        dropout_op->GetAttr("dropout_implementation"));
+    fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
+    fused_multi_transformer_op_desc.SetAttr("is_test", true);
 
     // parallel ring id
     auto* c_identity_op = c_identity->Op();
     fused_multi_transformer_op_desc.SetAttr("ring_id",
                                             c_identity_op->GetAttr("ring_id"));
-    // fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
-    // fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
 
     auto* fused_multi_transformer =
         graph->CreateOpNode(&fused_multi_transformer_op_desc);
     IR_NODE_LINK_TO(input0, fused_multi_transformer);
@@ -2641,6 +2488,15 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     IR_NODE_LINK_TO(slice_op, slice_out);
     IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
 
+    IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
+    IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
+    IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
+
     IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
   };
@@ -2790,12 +2646,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
                               ffn_eltadd1_out,
                               fused_multi_transformer_fuse_qkv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout,
-                              ffn_dropout,
-                              fused_multi_transformer_fuse_qkv_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
-                              ffn_dropout_out,
-                              fused_multi_transformer_fuse_qkv_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
                               ffn_eltadd_out,
                               fused_multi_transformer_fuse_qkv_pattern)
@@ -2827,11 +2677,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
                               softmax_qk_out,
                               fused_multi_transformer_fuse_qkv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(
-        dropout_qk_out, dropout_qk_out, fused_multi_transformer_fuse_qkv_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(
         matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
@@ -2873,12 +2718,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
                               eltadd_linear_out,
                               fused_multi_transformer_fuse_qkv_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
-                              dropout_linear,
-                              fused_multi_transformer_fuse_qkv_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
-                              dropout_linear_out,
-                              fused_multi_transformer_fuse_qkv_pattern)
     GET_IR_NODE_FROM_SUBGRAPH(
         eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
@@ -2893,11 +2732,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  matmul0_w,
                  eltadd0_b,
                  eltadd_qk_b,
-                 dropout_qk,
                  reshape2_0,
                  matmul_linear_w,
                  eltadd_linear_b,
-                 dropout_linear,
                  ffn_layer_norm,
                  ffn_layer_norm_scale,
                  ffn_layer_norm_bias,
@@ -2907,12 +2744,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
-                 ffn_dropout,
                  ffn_output);
 
    std::unordered_set<const Node*> marked_nodes({layer_norm,
-                                                 layer_norm_scale,
-                                                 layer_norm_bias,
                                                  layer_norm_mean,
                                                  layer_norm_variance,
                                                  layer_norm_out,
@@ -2944,8 +2778,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                  eltadd_qk_out,
                                                  softmax_qk,
                                                  softmax_qk_out,
-                                                 dropout_qk,
-                                                 dropout_qk_out,
                                                  transpose2_qkv,
                                                  transpose2_qkv_out,
                                                  matmul_qkv,
@@ -2954,19 +2786,13 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                  transpose2_qkv,
                                                  transpose2_qkv_out,
                                                  matmul_linear,
-                                                 matmul_linear_w,
                                                  matmul_linear_out,
                                                  c_allreduce_sum,
                                                  c_allreduce_sum_out,
                                                  eltadd_linear,
-                                                 eltadd_linear_b,
                                                  eltadd_linear_out,
-                                                 dropout_linear,
-                                                 dropout_linear_out,
                                                  eltadd_out,
                                                  ffn_layer_norm,
-                                                 ffn_layer_norm_scale,
-                                                 ffn_layer_norm_bias,
                                                  ffn_layer_norm_mean,
                                                  ffn_layer_norm_variance,
                                                  ffn_layer_norm_out,
@@ -2984,8 +2810,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                  ffn_eltadd1_out,
                                                  ffn_gelu,
                                                  ffn_gelu_out,
-                                                 ffn_dropout,
-                                                 ffn_dropout_out,
                                                  ffn_eltadd_out});
 
   // Remove unneeded nodes.
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h
@@ -88,8 +88,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_qk_out);
   PATTERN_DECL_NODE(softmax_qk);
   PATTERN_DECL_NODE(softmax_qk_out);
-  PATTERN_DECL_NODE(dropout_qk);
-  PATTERN_DECL_NODE(dropout_qk_out);
   // QK, V matmul
   PATTERN_DECL_NODE(matmul_qkv);
@@ -106,8 +104,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_linear);
   PATTERN_DECL_NODE(eltadd_linear_b);
   PATTERN_DECL_NODE(eltadd_linear_out);
-  PATTERN_DECL_NODE(dropout_linear);
-  PATTERN_DECL_NODE(dropout_linear_out);
   // output elementwise_add
   PATTERN_DECL_NODE(eltadd_out)
@@ -137,8 +133,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_out);
-  PATTERN_DECL_NODE(ffn_dropout);
-  PATTERN_DECL_NODE(ffn_dropout_out);
   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
@@ -193,8 +187,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_qk_out);
   PATTERN_DECL_NODE(softmax_qk);
   PATTERN_DECL_NODE(softmax_qk_out);
-  PATTERN_DECL_NODE(dropout_qk);
-  PATTERN_DECL_NODE(dropout_qk_out);
   // QK, V matmul
   PATTERN_DECL_NODE(matmul_qkv);
@@ -211,8 +203,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_linear);
   PATTERN_DECL_NODE(eltadd_linear_b);
   PATTERN_DECL_NODE(eltadd_linear_out);
-  PATTERN_DECL_NODE(dropout_linear);
-  PATTERN_DECL_NODE(dropout_linear_out);
   // output elementwise_add
   PATTERN_DECL_NODE(eltadd_out)
@@ -239,8 +229,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_out);
-  PATTERN_DECL_NODE(ffn_dropout);
-  PATTERN_DECL_NODE(ffn_dropout_out);
   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
@@ -299,8 +287,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
   PATTERN_DECL_NODE(eltadd_qk_out);
   PATTERN_DECL_NODE(softmax_qk);
   PATTERN_DECL_NODE(softmax_qk_out);
-  PATTERN_DECL_NODE(dropout_qk);
-  PATTERN_DECL_NODE(dropout_qk_out);
   // QK, V matmul
   PATTERN_DECL_NODE(matmul_qkv);
@@ -319,8 +305,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
   PATTERN_DECL_NODE(eltadd_linear);
   PATTERN_DECL_NODE(eltadd_linear_b);
   PATTERN_DECL_NODE(eltadd_linear_out);
-  PATTERN_DECL_NODE(dropout_linear);
-  PATTERN_DECL_NODE(dropout_linear_out);
   // output elementwise_add
   PATTERN_DECL_NODE(eltadd_out)
@@ -351,8 +335,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
   PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_out);
-  PATTERN_DECL_NODE(ffn_dropout);
-  PATTERN_DECL_NODE(ffn_dropout_out);
   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc
...
@@ -85,13 +85,11 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
...
@@ -85,13 +85,11 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
// (transpose_0, transpose_1) matmul -> matmul_qk
// (transpose_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
// (eltadd_out) elementwise_add -> attention_out
//
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...
@@ -100,8 +98,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
...
@@ -100,8 +98,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
Layers
layers
;
Layers
layers
;
// MHA: pre LayerNorm
// MHA: pre LayerNorm
...
@@ -154,10 +151,9 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
...
@@ -154,10 +151,9 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
dropout_qk
=
layers
.
dropout
(
softmax_qk
,
0.1
,
"upscale_in_train"
);
// MHA: QKV matmul
// MHA: QKV matmul
auto
*
matmul_qkv
=
layers
.
matmul_v2
(
dropout
_qk
,
concat_v
);
auto
*
matmul_qkv
=
layers
.
matmul_v2
(
softmax
_qk
,
concat_v
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
1024
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
1024
},
true
);
...
@@ -170,9 +166,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
...
@@ -170,9 +166,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
auto* linear_eltadd_out =
    layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv =
    layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...
@@ -195,9 +189,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
...
@@ -195,9 +189,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
auto* ffn_eltadd1_out =
    layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
layers.elementwise_add(attention_out, ffn_eltadd1_out);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
...
@@ -215,12 +207,12 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
...
@@ -215,12 +207,12 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before,
                  num_nodes_after + 72,
                  num_nodes_after + 60,
                  platform::errors::InvalidArgument(
                      "After the fused_multi_transformer_decoder_pass, The "
                      "node num in graph "
                      "should be %d, but the result is %d",
                      num_nodes_before - 72,
                      num_nodes_before - 60,
                      num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                  1,
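The new expected deltas (num_nodes_after + 60 instead of + 72, and the matching num_nodes_before - 60) only track how many nodes the hand-built test graph loses to fusion now that its dropout ops are gone; the assertion itself is unchanged. For context, a minimal sketch of how these testers typically drive the pass, assuming the Layers, CreateParamScope and GetNumOpNodes helpers already used in this file (illustration only, not part of the patch):

// Build the graph from the test program, attach a parameter scope,
// then apply the pass and compare node counts.
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get("fused_multi_transformer_decoder_pass");
int num_nodes_before = static_cast<int>(graph->Nodes().size());
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = static_cast<int>(graph->Nodes().size());
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
// The whole decoder block is expected to collapse into a single fused op.
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                  1,
                  platform::errors::InvalidArgument(
                      "expected exactly one fused_multi_transformer op, got %d",
                      num_fused_nodes_after));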
...
@@ -253,13 +245,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -253,13 +245,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_q, split_k) matmul -> matmul_qk
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
// (eltadd_out) elementwise_add -> attention_out
//
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...
@@ -268,8 +258,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -268,8 +258,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
//
// (transpose_1, transpose_2) while -> decoder block
// (transpose_1, transpose_2) while -> decoder block
...
@@ -313,10 +302,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -313,10 +302,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...
@@ -329,9 +317,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -329,9 +317,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* linear_eltadd_out =
    layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv =
    layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...
@@ -354,9 +340,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -354,9 +340,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* ffn_eltadd1_out =
    layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
layers.elementwise_add(attention_out, ffn_eltadd1_out);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
...
@@ -375,11 +359,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -375,11 +359,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ(num_nodes_before,
                  num_nodes_after + 62,
                  num_nodes_after + 50,
                  platform::errors::InvalidArgument(
                      "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
                      "The node num in graph should be %d, but the result is %d",
                      num_nodes_before - 62,
                      num_nodes_before - 50,
                      num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                  1,
...
@@ -413,14 +397,12 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -413,14 +397,12 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_q, split_k) matmul -> matmul_qk
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_allreduce_sum -> c_all_reduce_out
// (matmul_linear) c_allreduce_sum -> c_all_reduce_out
// (matmul_linear) elementwise_add -> eltadd_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
// (eltadd_out) elementwise_add -> attention_out
//
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...
@@ -431,8 +413,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -431,8 +413,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_allreduce_sum -> c_allreduce_out
// (ffn_matmul1) c_allreduce_sum -> c_allreduce_out
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
//
// (transpose_1, transpose_2) while -> decoder block
// (transpose_1, transpose_2) while -> decoder block
...
@@ -477,10 +458,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -477,10 +458,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...
@@ -494,9 +474,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -494,9 +474,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* linear_eltadd_out =
    layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* dropout_qkv =
    layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...
@@ -521,9 +499,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -521,9 +499,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* ffn_eltadd1_out =
    layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
layers.elementwise_add(attention_out, ffn_eltadd1_out);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
...
@@ -544,11 +520,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
...
@@ -544,11 +520,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ(num_nodes_before,
                  num_nodes_after + 70,
                  num_nodes_after + 58,
                  platform::errors::InvalidArgument(
                      "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
                      "The node num in graph should be %d, but the result is %d",
                      num_nodes_before - 70,
                      num_nodes_before - 58,
                      num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                  1,
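All three decoder tests apply the same simplification, so the attention/FFN sub-graph they build with the Layers helper now reads, in outline (same helper calls as above, shown here only to summarize the change):

// QK attention: the softmax output feeds the QKV matmul directly.
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
// Output projection: the residual add consumes the bias-add output.
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
// FFN: the second residual add likewise skips the dropout.
layers.elementwise_add(attention_out, ffn_eltadd1_out);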
...
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
Browse file @
9ad0e37e
...
@@ -227,15 +227,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -227,15 +227,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
                               ->assert_is_op_output("softmax")
                               ->AsIntermediate()
                               ->assert_is_op_input("dropout");
                               ->assert_is_op_input("matmul_v2", "X");
auto* dropout_qk =
    pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var = pattern->NewNode(dropout_qk_out_repr())
                               ->assert_is_op_output("dropout", "Out")
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2", "X");
// -> matmul_qkv
// QK path Linsk
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
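With the dropout nodes gone, the softmax output variable is asserted to be the X input of the QKV matmul, so the matched chain is softmax -> matmul_v2 with nothing in between. Put together, this piece of the simplified pattern reads roughly as follows (same PDPattern/GraphPatternDetector calls and node names as above, consolidated here for illustration only):

// The softmax output var now links straight into the QKV matmul.
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
                               ->assert_is_op_output("softmax")
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2", "X");
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
    .LinksTo({matmul_qkv_out_var});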
...
@@ -243,7 +235,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -243,7 +235,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
    .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
...
@@ -284,14 +275,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -284,14 +275,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
                                  ->assert_is_op_output("elementwise_add")
                                  ->AsIntermediate()
                                  ->assert_is_op_input("dropout");
                                  ->assert_is_op_input("elementwise_add");
auto* dropout_linear =
    pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
                                   ->assert_is_op_output("dropout")
                                   ->AsIntermediate()
                                   ->assert_is_op_input("elementwise_add");
auto* eltadd_out =
    pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
...
@@ -300,7 +284,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -300,7 +284,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
    ->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, transpose2_2_out_var})
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
    .LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
    .LinksTo({transpose2_qkv_out_var});
...
@@ -310,9 +294,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -310,9 +294,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
    .LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
    .LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
    .LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
    .LinksTo({attention_output});
// while loop
...
@@ -352,7 +334,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -352,7 +334,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
                              ffn_layer_norm_mean_var,
                              ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
// Feed Forward fc1 -> gelu -> fc2
auto* ffn_matmul0 =
    pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
...
@@ -397,13 +379,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -397,13 +379,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
                                ->assert_is_op_output("elementwise_add")
                                ->AsIntermediate()
                                ->assert_is_op_input("dropout");
                                ->assert_is_op_input("elementwise_add");
auto* ffn_dropout =
    pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
                                ->assert_is_op_output("dropout")
                                ->AsIntermediate()
                                ->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
...
@@ -421,9 +396,8 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
...
@@ -421,9 +396,8 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
    .LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
    .LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
    .LinksTo({ffn_output});
return ffn_output;
...
@@ -545,15 +519,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -545,15 +519,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
                               ->assert_is_op_output("softmax")
                               ->AsIntermediate()
                               ->assert_is_op_input("dropout");
                               ->assert_is_op_input("matmul_v2", "X");
auto* dropout_qk =
    pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var = pattern->NewNode(dropout_qk_out_repr())
                               ->assert_is_op_output("dropout", "Out")
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2", "X");
// -> matmul_qkv
// QK path Linsk
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
...
@@ -561,7 +527,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -561,7 +527,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
    .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
...
@@ -602,14 +567,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -602,14 +567,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
                                  ->assert_is_op_output("elementwise_add")
                                  ->AsIntermediate()
                                  ->assert_is_op_input("dropout");
                                  ->assert_is_op_input("elementwise_add");
auto* dropout_linear =
    pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
                                   ->assert_is_op_output("dropout")
                                   ->AsIntermediate()
                                   ->assert_is_op_input("elementwise_add");
auto* eltadd_out =
    pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
...
@@ -618,7 +576,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -618,7 +576,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
    ->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var})
matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var})
    .LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
    .LinksTo({transpose2_qkv_out_var});
...
@@ -628,9 +586,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -628,9 +586,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
    .LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
    .LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
    .LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
    .LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
...
@@ -666,7 +622,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -666,7 +622,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
                              ffn_layer_norm_mean_var,
                              ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
// Feed Forward fc1 -> gelu -> fc2
auto* ffn_matmul0 =
    pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
...
@@ -711,13 +667,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -711,13 +667,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
                                ->assert_is_op_output("elementwise_add")
                                ->AsIntermediate()
                                ->assert_is_op_input("dropout");
                                ->assert_is_op_input("elementwise_add");
auto* ffn_dropout =
    pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
                                ->assert_is_op_output("dropout")
                                ->AsIntermediate()
                                ->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
...
@@ -735,9 +684,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -735,9 +684,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
    .LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
    .LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
    .LinksTo({ffn_output});
return ffn_output;
...
@@ -868,15 +816,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -868,15 +816,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
                               ->assert_is_op_output("softmax")
                               ->AsIntermediate()
                               ->assert_is_op_input("dropout");
                               ->assert_is_op_input("matmul_v2", "X");
auto* dropout_qk =
    pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var = pattern->NewNode(dropout_qk_out_repr())
                               ->assert_is_op_output("dropout", "Out")
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2", "X");
// -> matmul_qkv
// QK path Linsk
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
...
@@ -884,7 +824,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -884,7 +824,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
    .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
...
@@ -933,14 +872,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -933,14 +872,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
                                  ->assert_is_op_output("elementwise_add")
                                  ->AsIntermediate()
                                  ->assert_is_op_input("dropout");
                                  ->assert_is_op_input("elementwise_add");
auto* dropout_linear =
    pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
                                   ->assert_is_op_output("dropout")
                                   ->AsIntermediate()
                                   ->assert_is_op_input("elementwise_add");
auto* eltadd_out =
    pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
...
@@ -949,7 +881,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -949,7 +881,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
    ->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var})
matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var})
    .LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
    .LinksTo({transpose2_qkv_out_var});
...
@@ -961,9 +893,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -961,9 +893,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
    .LinksTo({c_allreduce_sum_out_var});
eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
    .LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
    .LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
    .LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
...
@@ -1009,7 +939,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -1009,7 +939,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
    .LinksTo({ffn_c_identity_out_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
// Feed Forward fc1 -> gelu -> fc2
auto* ffn_matmul0 =
    pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
...
@@ -1063,13 +993,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -1063,13 +993,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
                                ->assert_is_op_output("elementwise_add")
                                ->AsIntermediate()
                                ->assert_is_op_input("dropout");
                                ->assert_is_op_input("elementwise_add");
auto* ffn_dropout =
    pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
                                ->assert_is_op_output("dropout")
                                ->AsIntermediate()
                                ->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
...
@@ -1089,9 +1012,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
...
@@ -1089,9 +1012,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
    .LinksTo({ffn_c_allreduce_sum_out_var});
ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
    .LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
    .LinksTo({ffn_output});
return ffn_output;
...
@@ -1253,11 +1175,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1253,11 +1175,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node* transpose2_1_out,
Node* transpose2_2_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
...
@@ -1268,7 +1188,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1268,7 +1188,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
auto reshape_desc = reshape2_0->Op();
int num_head =
...
@@ -1375,7 +1294,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1375,7 +1294,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32));
fill_const_op_desc.SetAttr(
    "dtype",
    static_cast<int>(framework::TransToProtoVarType(wq_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
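Besides removing the dropout nodes, the patch stops hard-coding FP32 for the CacheKV fill_constant; the dtype now follows the matched attention weight tensor, so an FP16 model gets an FP16 cache. The idiom, as used right above (wq_tensor is the matched Q weight tensor here):

// Derive the attribute dtype from an existing weight tensor instead of
// assuming FP32.
auto cache_dtype = framework::TransToProtoVarType(wq_tensor->dtype());
fill_const_op_desc.SetAttr("dtype", static_cast<int>(cache_dtype));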
...
@@ -1409,15 +1330,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1409,15 +1330,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc.SetAttr(
    "epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr("dropout_rate",
                                        dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
                                        dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr("dropout_implementation",
                                        dropout_op->GetAttr("dropout_implementation"));
fused_multi_transformer_op_desc.SetAttr("is_test", true);
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
auto* fused_multi_transformer =
    graph->CreateOpNode(&fused_multi_transformer_op_desc);
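Because the pattern no longer matches any dropout op, the fused op's dropout attributes can no longer be copied from a matched dropout_op; the pass now pins the fused op to its inference configuration instead. In sketch form (these two SetAttr calls are exactly what the pass emits after this change):

// Always emit the fused op in inference form: no dropout applied.
fused_multi_transformer_op_desc.SetAttr("is_test", true);
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
// Previously these were read from the matched dropout op, e.g.
// dropout_op->GetAttr("dropout_prob") and dropout_op->GetAttr("is_test").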
...
@@ -1433,6 +1347,15 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1433,6 +1347,15 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input
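The added IR_NODE_LINK_TO calls attach every persistable input consumed by the fused op (projection weight and bias, FFN layer-norm parameters, FFN weights and biases) to the new op node, so the rewritten graph keeps explicit edges from those variables to their consumer. The idiom itself is just:

// IR_NODE_LINK_TO(from, to) inserts a directed edge from -> to in the graph.
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);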
...
@@ -1620,11 +1543,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1620,11 +1543,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(
    ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout, ffn_dropout, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
...
@@ -1668,11 +1586,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1668,11 +1586,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
    softmax_qk, softmax_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
    softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk, dropout_qk, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
...
@@ -1700,11 +1613,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1700,11 +1613,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
    eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear, dropout_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    eltadd_out, eltadd_out, fused_multi_transformer_pattern)
...
@@ -1723,11 +1631,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1723,11 +1631,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
transpose2_1_out,
transpose2_2_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
while0,
ffn_layer_norm,
ffn_layer_norm_scale,
...
@@ -1738,12 +1644,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1738,12 +1644,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
                                              layer_norm_scale,
                                              layer_norm_bias,
                                              layer_norm_mean,
                                              layer_norm_variance,
                                              layer_norm_out,
...
@@ -1777,8 +1680,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1777,8 +1680,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
...
@@ -1787,17 +1688,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1787,17 +1688,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
...
@@ -1811,8 +1706,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
...
@@ -1811,8 +1706,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
...
@@ -2016,11 +1909,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2016,11 +1909,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* split0_k_out,
Node* split0_v_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
...
@@ -2031,7 +1922,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2031,7 +1922,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
auto reshape_desc = reshape2_0->Op();
int num_head =
...
@@ -2104,7 +1994,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2104,7 +1994,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32));
fill_const_op_desc.SetAttr(
    "dtype",
    static_cast<int>(framework::TransToProtoVarType(qkv_w_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...
@@ -2139,14 +2031,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2139,14 +2031,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
// output dropout attribute
auto
*
dropout_op
=
dropout_linear
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
"dropout_rate"
,
dropout_op
->
GetAttr
(
"dropout_prob"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
dropout_op
->
GetAttr
(
"is_test"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_implementation"
,
dropout_op
->
GetAttr
(
"dropout_implementation"
));
auto
*
fused_multi_transformer
=
auto
*
fused_multi_transformer
=
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
...
@@ -2162,6 +2048,15 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2162,6 +2048,15 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input
...
@@ -2315,12 +2210,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2315,12 +2210,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
    ffn_eltadd1_out,
    fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
    ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
                          ffn_eltadd_out,
                          fused_multi_transformer_fuse_qkv_pattern)
...
@@ -2352,11 +2241,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2352,11 +2241,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
                          softmax_qk_out,
                          fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
    dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    dropout_qk_out, dropout_qk_out, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
...
@@ -2392,12 +2276,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2392,12 +2276,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
                          eltadd_linear_out,
                          fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    dropout_linear, dropout_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    dropout_linear_out, dropout_linear_out, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
...
@@ -2416,11 +2294,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2416,11 +2294,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
split0_k_out,
split0_v_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
while0,
ffn_layer_norm,
ffn_layer_norm_scale,
...
@@ -2431,12 +2307,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2431,12 +2307,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
                                              layer_norm_scale,
                                              layer_norm_bias,
                                              layer_norm_mean,
                                              layer_norm_variance,
                                              layer_norm_out,
...
@@ -2458,8 +2331,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2458,8 +2331,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
...
@@ -2468,17 +2339,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2468,17 +2339,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
...
@@ -2492,8 +2357,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2492,8 +2357,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
...
@@ -2700,11 +2563,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2700,11 +2563,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* split0_k_out,
Node* split0_v_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
...
@@ -2715,7 +2576,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2715,7 +2576,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
auto reshape_desc = reshape2_0->Op();
int num_head =
...
@@ -2789,7 +2649,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2789,7 +2649,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32));
fill_const_op_desc.SetAttr(
    "dtype",
    static_cast<int>(framework::TransToProtoVarType(qkv_w_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...
@@ -2824,14 +2686,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2824,14 +2686,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
// output dropout attribute
auto
*
dropout_op
=
dropout_linear
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
0.0
f
);
fused_multi_transformer_op_desc
.
SetAttr
(
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
true
);
"dropout_rate"
,
dropout_op
->
GetAttr
(
"dropout_prob"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
dropout_op
->
GetAttr
(
"is_test"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_implementation"
,
dropout_op
->
GetAttr
(
"dropout_implementation"
));
// parallel ring id
// parallel ring id
auto
*
c_identity_op
=
c_identity
->
Op
();
auto
*
c_identity_op
=
c_identity
->
Op
();
...
@@ -2852,6 +2708,15 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -2852,6 +2708,15 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input
...
@@ -3024,12 +2889,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
...
@@ -3024,12 +2889,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
    ffn_eltadd1_out,
    fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
    ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
    ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
                          ffn_eltadd_out,
                          fused_multi_transformer_fuse_qkv_pattern)
...
@@ -3061,11 +2920,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
                               softmax_qk_out,
                               fused_multi_transformer_fuse_qkv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
-    GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
-                              dropout_qk_out,
-                              fused_multi_transformer_fuse_qkv_pattern)

     GET_IR_NODE_FROM_SUBGRAPH(
         matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
...
@@ -3107,12 +2961,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
                               eltadd_linear_out,
                               fused_multi_transformer_fuse_qkv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
-                              dropout_linear,
-                              fused_multi_transformer_fuse_qkv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
-                              dropout_linear_out,
-                              fused_multi_transformer_fuse_qkv_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(
         eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern);
...
@@ -3132,11 +2980,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                  split0_k_out,
                  split0_v_out,
                  eltadd_qk_b,
-                 dropout_qk,
                  reshape2_0,
                  matmul_linear_w,
                  eltadd_linear_b,
-                 dropout_linear,
                  while0,
                  ffn_layer_norm,
                  ffn_layer_norm_scale,
...
@@ -3147,12 +2993,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
-                 ffn_dropout,
                  ffn_output);

    std::unordered_set<const Node*> marked_nodes({layer_norm,
-                                                  layer_norm_scale,
-                                                  layer_norm_bias,
                                                   layer_norm_mean,
                                                   layer_norm_variance,
                                                   layer_norm_out,
...
@@ -3176,8 +3019,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                                                   eltadd_qk_out,
                                                   softmax_qk,
                                                   softmax_qk_out,
-                                                  dropout_qk,
-                                                  dropout_qk_out,
                                                   transpose2_qkv,
                                                   transpose2_qkv_out,
                                                   matmul_qkv,
...
@@ -3186,19 +3027,13 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                                                   transpose2_qkv,
                                                   transpose2_qkv_out,
                                                   matmul_linear,
-                                                  matmul_linear_w,
                                                   matmul_linear_out,
                                                   c_allreduce_sum,
                                                   c_allreduce_sum_out,
                                                   eltadd_linear,
-                                                  eltadd_linear_b,
                                                   eltadd_linear_out,
-                                                  dropout_linear,
-                                                  dropout_linear_out,
                                                   eltadd_out,
                                                   ffn_layer_norm,
-                                                  ffn_layer_norm_scale,
-                                                  ffn_layer_norm_bias,
                                                   ffn_layer_norm_mean,
                                                   ffn_layer_norm_variance,
                                                   ffn_layer_norm_out,
...
@@ -3216,8 +3051,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                                                   ffn_eltadd1_out,
                                                   ffn_gelu,
                                                   ffn_gelu_out,
-                                                  ffn_dropout,
-                                                  ffn_dropout_out,
                                                   ffn_eltadd_out});

    // Remove unneeded nodes.
...
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h
...
@@ -82,8 +82,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_qk_out);
   PATTERN_DECL_NODE(softmax_qk);
   PATTERN_DECL_NODE(softmax_qk_out);
-  PATTERN_DECL_NODE(dropout_qk);
-  PATTERN_DECL_NODE(dropout_qk_out);

   // QK, V matmul
   PATTERN_DECL_NODE(matmul_qkv);
...
@@ -100,8 +98,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_linear);
   PATTERN_DECL_NODE(eltadd_linear_b);
   PATTERN_DECL_NODE(eltadd_linear_out);
-  PATTERN_DECL_NODE(dropout_linear);
-  PATTERN_DECL_NODE(dropout_linear_out);

   // output elementwise_add
   PATTERN_DECL_NODE(eltadd_out)
...
@@ -131,8 +127,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_out);
-  PATTERN_DECL_NODE(ffn_dropout);
-  PATTERN_DECL_NODE(ffn_dropout_out);

   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
...
@@ -179,8 +173,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_qk_out);
   PATTERN_DECL_NODE(softmax_qk);
   PATTERN_DECL_NODE(softmax_qk_out);
-  PATTERN_DECL_NODE(dropout_qk);
-  PATTERN_DECL_NODE(dropout_qk_out);

   // QK, V matmul
   PATTERN_DECL_NODE(matmul_qkv);
...
@@ -200,8 +192,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_linear);
   PATTERN_DECL_NODE(eltadd_linear_b);
   PATTERN_DECL_NODE(eltadd_linear_out);
-  PATTERN_DECL_NODE(dropout_linear);
-  PATTERN_DECL_NODE(dropout_linear_out);

   // output elementwise_add
   PATTERN_DECL_NODE(eltadd_out)
...
@@ -228,8 +218,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_out);
-  PATTERN_DECL_NODE(ffn_dropout);
-  PATTERN_DECL_NODE(ffn_dropout_out);

   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
...
@@ -280,8 +268,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
   PATTERN_DECL_NODE(eltadd_qk_out);
   PATTERN_DECL_NODE(softmax_qk);
   PATTERN_DECL_NODE(softmax_qk_out);
-  PATTERN_DECL_NODE(dropout_qk);
-  PATTERN_DECL_NODE(dropout_qk_out);

   // QK, V matmul
   PATTERN_DECL_NODE(matmul_qkv);
...
@@ -303,8 +289,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
   PATTERN_DECL_NODE(eltadd_linear);
   PATTERN_DECL_NODE(eltadd_linear_b);
   PATTERN_DECL_NODE(eltadd_linear_out);
-  PATTERN_DECL_NODE(dropout_linear);
-  PATTERN_DECL_NODE(dropout_linear_out);

   // output elementwise_add
   PATTERN_DECL_NODE(eltadd_out)
...
@@ -335,8 +319,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
   PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_out);
-  PATTERN_DECL_NODE(ffn_dropout);
-  PATTERN_DECL_NODE(ffn_dropout_out);

   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
...
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc
...
@@ -81,13 +81,11 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
  // (transpose_0, transpose_1)     matmul          -> matmul_qk
  // (matmul_qk, bias_qk)           elementwise_add -> eltadd_qk
  // (eltadd_qk)                    softmax         -> softmax_qk
- // (softmax_qk)                   dropout         -> dropout_qk
- // (dropout_qk, transpose_2)      matmul_v2       -> matmul_qkv
+ // (softmax_qk, transpose_2)      matmul_v2       -> matmul_qkv
  // (matmul_qkv)                   transpose       -> transpose_qkv
  // (transpose_qkv)                reshape         -> reshape_qkv
  // (reshape_qkv)                  matmul_v2       -> matmul_linear
  // (matmul_linear)                elementwise_add -> eltadd_linear
- // (eltadd_linear)                dropout         -> dropout_linear
  // (eltadd_out)                   elementwise_add -> attention_out
  //
  // (attention_out, scale, bias)   layer_norm      -> ffn_layer_norm_out
...
@@ -96,8 +94,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
  // (ffn_eltadd0)                  gelu            -> ffn_gelu
  // (ffn_gelu)                     matmul_v2       -> ffn_matmul1
  // (ffn_matmul1, ffn_bias1)       elementwise_add -> ffn_eltadd1
- // (ffn_eltadd1)                  dropout         -> ffn_dropout
- // (attention_out, ffn_dropout)   elementwise_add -> ffn_output
+ // (attention_out, ffn_eltadd1)   elementwise_add -> ffn_output
  //
  // (transpose_1, transpose_2)     while           -> decoder block
...
@@ -149,10 +146,9 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
-  auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");

   // MHA: QKV matmul
-  auto* matmul_qkv = layers.matmul_v2(dropout_qk, transpose_2);
+  auto* matmul_qkv = layers.matmul_v2(softmax_qk, transpose_2);
   auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
   auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...
@@ -165,9 +161,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   auto* linear_eltadd_out =
       layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
-  auto* dropout_qkv =
-      layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
-  auto* attention_out = layers.elementwise_add(x, dropout_qkv);
+  auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);

   // FFN: pre LayerNorm
   auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...
@@ -190,9 +184,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   auto* ffn_eltadd1_out =
       layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);

- // FFN: dropout -> elementwise_add
-  auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
-  layers.elementwise_add(attention_out, ffn_dropout);
+  layers.elementwise_add(attention_out, ffn_eltadd1_out);

   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   graph->Set("__param_scope__", CreateParamScope());
...
@@ -210,12 +202,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");

   PADDLE_ENFORCE_EQ(num_nodes_before,
-                    num_nodes_after + 68,
+                    num_nodes_after + 56,
                     platform::errors::InvalidArgument(
                         "After the fused_multi_transformer_encoder_pass, The "
                         "node num in graph "
                         "should be %d, but the result is %d",
-                        num_nodes_before - 68,
+                        num_nodes_before - 56,
                         num_nodes_after));
   PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                     1,
...
@@ -246,13 +238,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
  // (split_q, split_k)             matmul          -> matmul_qk
  // (matmul_qk, bias_qk)           elementwise_add -> eltadd_qk
  // (eltadd_qk)                    softmax         -> softmax_qk
- // (softmax_qk)                   dropout         -> dropout_qk
- // (dropout_qk, transpose_2)      matmul_v2       -> matmul_qkv
+ // (softmax_qk, transpose_2)      matmul_v2       -> matmul_qkv
  // (matmul_qkv)                   transpose       -> transpose_qkv
  // (transpose_qkv)                reshape         -> reshape_qkv
  // (reshape_qkv)                  matmul_v2       -> matmul_linear
  // (matmul_linear)                elementwise_add -> eltadd_linear
- // (eltadd_linear)                dropout         -> dropout_linear
  // (eltadd_out)                   elementwise_add -> attention_out
  //
  // (attention_out, scale, bias)   layer_norm      -> ffn_layer_norm_out
...
@@ -261,8 +251,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
  // (ffn_eltadd0)                  gelu            -> ffn_gelu
  // (ffn_gelu)                     matmul_v2       -> ffn_matmul1
  // (ffn_matmul1, ffn_bias1)       elementwise_add -> ffn_eltadd1
- // (ffn_eltadd1)                  dropout         -> ffn_dropout
- // (attention_out, ffn_dropout)   elementwise_add -> ffn_output
+ // (attention_out, ffn_eltadd1)   elementwise_add -> ffn_output
  //
  // (transpose_1, transpose_2)     while           -> decoder block
...
@@ -304,10 +293,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
-  auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");

   // MHA: QKV matmul
-  auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v);
+  auto* matmul_qkv = layers.matmul_v2(softmax_qk, split_v);
   auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
   auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...
@@ -320,9 +308,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* linear_eltadd_out =
       layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
-  auto* dropout_qkv =
-      layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
-  auto* attention_out = layers.elementwise_add(x, dropout_qkv);
+  auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);

   // FFN: pre LayerNorm
   auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...
@@ -345,9 +331,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* ffn_eltadd1_out =
       layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);

- // FFN: dropout -> elementwise_add
-  auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
-  layers.elementwise_add(attention_out, ffn_dropout);
+  layers.elementwise_add(attention_out, ffn_eltadd1_out);

   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   graph->Set("__param_scope__", CreateParamScope());
...
@@ -366,11 +350,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
   PADDLE_ENFORCE_EQ(num_nodes_before,
-                    num_nodes_after + 56,
+                    num_nodes_after + 44,
                     platform::errors::InvalidArgument(
                         "After the fused_multi_transformer_encoder_fuse_qkv_pass, "
                         "The node num in graph should be %d, but the result is %d",
-                        num_nodes_before - 56,
+                        num_nodes_before - 44,
                         num_nodes_after));
   PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                     1,
...
@@ -402,14 +386,12 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
  // (split_q, split_k)             matmul          -> matmul_qk
  // (matmul_qk, bias_qk)           elementwise_add -> eltadd_qk
  // (eltadd_qk)                    softmax         -> softmax_qk
- // (softmax_qk)                   dropout         -> dropout_qk
- // (dropout_qk, transpose_2)      matmul_v2       -> matmul_qkv
+ // (softmax_qk, transpose_2)      matmul_v2       -> matmul_qkv
  // (matmul_qkv)                   transpose       -> transpose_qkv
  // (transpose_qkv)                reshape         -> reshape_qkv
  // (reshape_qkv)                  matmul_v2       -> matmul_linear
  // (matmul_linear)                c_all_reduce    -> c_all_reduce_out
  // (c_all_reduce_out)             elementwise_add -> eltadd_linear
- // (eltadd_linear)                dropout         -> dropout_linear
  // (eltadd_out)                   elementwise_add -> attention_out
  //
  // (attention_out, scale, bias)   layer_norm      -> ffn_layer_norm_out
...
@@ -420,8 +402,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
  // (ffn_gelu)                         matmul_v2       -> ffn_matmul1
  // (ffn_matmul1)                      c_all_reduce    -> ffn_c_all_reduce_out
  // (ffn_c_all_reduce_out, ffn_bias1)  elementwise_add -> ffn_eltadd1
- // (ffn_eltadd1)                      dropout         -> ffn_dropout
- // (attention_out, ffn_dropout)       elementwise_add -> ffn_output
+ // (attention_out, ffn_eltadd1)       elementwise_add -> ffn_output
  //
  // (transpose_1, transpose_2)         while           -> decoder block
...
@@ -464,10 +445,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
-  auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");

   // MHA: QKV matmul
-  auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v);
+  auto* matmul_qkv = layers.matmul_v2(softmax_qk, split_v);
   auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
   auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...
@@ -481,9 +461,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* linear_eltadd_out =
       layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
-  auto* dropout_qkv =
-      layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
-  auto* attention_out = layers.elementwise_add(x, dropout_qkv);
+  auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);

   // FFN: pre LayerNorm
   auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...
@@ -508,9 +486,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* ffn_eltadd1_out =
       layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2);

- // FFN: dropout -> elementwise_add
-  auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
-  layers.elementwise_add(attention_out, ffn_dropout);
+  layers.elementwise_add(attention_out, ffn_eltadd1_out);

   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   graph->Set("__param_scope__", CreateParamScope());
...
@@ -531,11 +507,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
   PADDLE_ENFORCE_EQ(num_nodes_before,
-                    num_nodes_after + 64,
+                    num_nodes_after + 52,
                     platform::errors::InvalidArgument(
                         "After the fused_multi_transformer_encoder_fuse_qkv_pass, "
                         "The node num in graph should be %d, but the result is %d",
-                        num_nodes_before - 64,
+                        num_nodes_before - 52,
                         num_nodes_after));
   PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                     1,
...
...
paddle/fluid/framework/ir/pass.cc
...
@@ -39,6 +39,7 @@ namespace ir {
 static const char kParamScopeAttr[] = "__param_scope__";

 static const std::vector<std::string> support_subgraph_passes = {
+    "simplify_with_basic_ops_pass",
     "fused_multi_transformer_encoder_pass",
     "fused_multi_transformer_decoder_pass",
     "fused_multi_transformer_encoder_fuse_qkv_pass",
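Note: the only change here is adding "simplify_with_basic_ops_pass" to the list of passes that are also applied to control-flow sub-blocks (the decoder's while block). That pass removes inference-time dropout ops, which is presumably why the fusion patterns above no longer declare or match dropout nodes. A small standalone sketch, with assumed names rather than Paddle's pass machinery, of how such a whitelist is typically consulted:

// Standalone sketch, not Paddle code: apply a pass to sub-blocks only when it
// is on the support list.
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

static const std::vector<std::string> kSupportSubgraphPasses = {
    "simplify_with_basic_ops_pass",
    "fused_multi_transformer_encoder_pass",
    "fused_multi_transformer_decoder_pass",
    "fused_multi_transformer_encoder_fuse_qkv_pass",
};

bool AppliesToSubgraphs(const std::string& pass_name) {
  return std::find(kSupportSubgraphPasses.begin(),
                   kSupportSubgraphPasses.end(),
                   pass_name) != kSupportSubgraphPasses.end();
}

int main() {
  assert(AppliesToSubgraphs("simplify_with_basic_ops_pass"));
  assert(!AppliesToSubgraphs("some_other_pass"));
  return 0;
}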
...
...