BaiXuePrincess / Paddle — forked from PaddlePaddle / Paddle (in sync with the fork source)
Commit 29eec2dd (unverified), authored Jan 04, 2023 by lzy; committed by GitHub on Jan 04, 2023.
add multi_devices_fused_multi_transformer_encoder_pass and cherry-pick from 48349 (#49383)
Parent: a2d7e1d7
Showing 6 changed files with 2776 additions and 1241 deletions (+2776 / −1241).
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc          +50    −36
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h           +6     −6
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc          +2341  −1140
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h           +164   −20
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc   +213   −39
paddle/fluid/inference/api/paddle_pass_builder.cc                          +2     −0
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
@@ -31,6 +31,8 @@ namespace framework {
 namespace ir {
 namespace patterns {
 
+static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
+
 PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* input0 = pattern->NewNode(input0_repr());
   input0->assert_is_op_input("layer_norm", "X");
@@ -359,11 +361,11 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2");
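The effect of this hunk: where the pattern previously pinned the FFN activation to a single "gelu" op, it now accepts any op type in FFN_ACTS. A minimal side-by-side sketch using the same pattern-API calls seen in this diff (the surrounding pattern setup is assumed):

// Old: matches only a gelu activation between the two FFN matmuls.
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");

// New: one node matches either relu or gelu, so a single pass covers
// both FFN variants instead of needing one pattern per activation.
static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);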
@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
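In this API, each op node declares its incoming variable nodes with LinksFrom and its outputs with LinksTo, so the rewritten chain reads as one FFN block. A condensed sketch of the matched wiring (node construction elided; ffn_in_var is a stand-in for the ffn_matmul0 input, which this hunk does not show):

// One FFN block: matmul0 -> bias add -> activation (relu or gelu)
//                -> matmul1 -> bias add.
ffn_matmul0->LinksFrom({ffn_in_var, ffn_matmul0_w_var})
    .LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
    .LinksTo({ffn_eltadd0_out_var});
ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
    .LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
    .LinksTo({ffn_eltadd1_out_var});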
@@ -678,11 +680,11 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2");
@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
@@ -1026,11 +1028,11 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
                               ->AsIntermediate()
                               ->assert_is_op_input("matmul_v2");
@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
       .LinksTo({ffn_c_allreduce_sum_out_var});
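The multi-devices variants additionally expect collective ops around the parallel matmuls: a c_identity before each column-parallel weight and a c_allreduce_sum after the row-parallel output matmul, the standard tensor-parallel linear layout. A sketch of the extra FFN edge this hunk keeps in place (names from the diff; the parallelism interpretation is the editor's reading):

// Each device holds a partial ffn_matmul1 result; the allreduce sums the
// partials before the bias add, completing the row-parallel linear layer.
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
    .LinksTo({ffn_c_allreduce_sum_out_var});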
@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                         Node* ffn_matmul1_w,
                         Node* ffn_eltadd0_b,
                         Node* ffn_eltadd1_b,
+                        Node* ffn_act,
                         Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("is_test", true);
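Recording ffn_act->Op()->Type() (the literal op type, "relu" or "gelu") as the fused op's act_method attribute is what lets one fused kernel serve both activations; how the kernel dispatches on it is outside this diff. A hedged sketch of the attribute round-trip (the reading side is an assumption, not shown in this commit):

// Fusion side (this diff): stamp the matched activation onto the fused op.
fused_multi_transformer_op_desc.SetAttr("act_method", ffn_act->Op()->Type());

// Kernel side (assumed): read it back and branch, e.g.
//   std::string act = op->Attr<std::string>("act_method");
//   if (act == "gelu") { /* gelu path */ } else { /* relu path */ }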
@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
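GET_IR_NODE_FROM_SUBGRAPH(var, pat_node, pattern) declares a local Node* var bound to the concrete graph node the detector matched for pattern node pat_node, so the renames here keep the fusion handler aligned with the renamed pattern nodes. A rough sketch of where these macros live (the handler structure is assumed from the usual GraphPatternDetector idiom, not copied from this commit):

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                   Graph* graph) {
  // ffn_act may now be bound to either a relu or a gelu node.
  GET_IR_NODE_FROM_SUBGRAPH(ffn_act, ffn_act, fused_multi_transformer_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(
      ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
  // ... assemble the fused_multi_transformer op, then erase matched nodes.
};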
@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                 ffn_matmul1_w,
                 ffn_eltadd0_b,
                 ffn_eltadd1_b,
+                ffn_act,
                 ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.
@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                          Node* ffn_matmul1_w,
                          Node* ffn_eltadd0_b,
                          Node* ffn_eltadd1_b,
+                         Node* ffn_act,
                          Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                               fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                 ffn_matmul1_w,
                 ffn_eltadd0_b,
                 ffn_eltadd1_b,
+                ffn_act,
                 ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.
@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                          Node* ffn_matmul1_w,
                          Node* ffn_eltadd0_b,
                          Node* ffn_eltadd1_b,
+                         Node* ffn_act,
                          Node* ffn_output) {
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_1_op = ffn_matmul1->Op();
@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                               fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                 ffn_matmul1_w,
                 ffn_eltadd0_b,
                 ffn_eltadd1_b,
+                ffn_act,
                 ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h
@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
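Each PATTERN_DECL_NODE(x) in these headers pairs with the pattern->NewNode(x_repr()) calls in the .cc file: the macro generates the x_repr() helper that produces a unique node name within the pattern. A simplified sketch of the idea (not the exact Paddle definition, which lives in graph_pattern_detector.h):

// Roughly what each declaration provides (simplified, assumed):
#define PATTERN_DECL_NODE(name__)                         \
  std::string name__##_repr() const {                     \
    return PDNodeName(name_scope_, repr_, id_, #name__);  \
  }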
@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
(Diff collapsed in this view.)
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h
@@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   // Q, K, V path
   PATTERN_DECL_NODE(input0);
-  PATTERN_DECL_NODE(layer_norm);
-  PATTERN_DECL_NODE(layer_norm_scale);
-  PATTERN_DECL_NODE(layer_norm_bias);
-  PATTERN_DECL_NODE(layer_norm_mean);
-  PATTERN_DECL_NODE(layer_norm_variance);
-  PATTERN_DECL_NODE(layer_norm_out);
   PATTERN_DECL_NODE(matmul0);
   PATTERN_DECL_NODE(matmul1);
   PATTERN_DECL_NODE(matmul2);
@@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   PATTERN_DECL_NODE(transpose2_0_out);
   PATTERN_DECL_NODE(transpose2_1_out);
   PATTERN_DECL_NODE(transpose2_2_out);
+  PATTERN_DECL_NODE(scale_q);
+  PATTERN_DECL_NODE(scale_q_out);
 
   // Q, K matmul
   PATTERN_DECL_NODE(matmul_qk);
@@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_linear);
   PATTERN_DECL_NODE(eltadd_linear_b);
   PATTERN_DECL_NODE(eltadd_linear_out);
-  PATTERN_DECL_NODE(dropout_linear);
-  PATTERN_DECL_NODE(dropout_linear_out);
 
   // output elementwise_add
   PATTERN_DECL_NODE(eltadd_out)
   PATTERN_DECL_NODE(attention_output);
 
+  // while loop
+  PATTERN_DECL_NODE(while0);
+
+  // post layer_norm
+  PATTERN_DECL_NODE(layer_norm);
+  PATTERN_DECL_NODE(layer_norm_scale);
+  PATTERN_DECL_NODE(layer_norm_bias);
+  PATTERN_DECL_NODE(layer_norm_mean);
+  PATTERN_DECL_NODE(layer_norm_variance);
+  PATTERN_DECL_NODE(layer_norm_out);
+
   // Feed Forward nodes
-  PATTERN_DECL_NODE(ffn_layer_norm);
-  PATTERN_DECL_NODE(ffn_layer_norm_scale);
-  PATTERN_DECL_NODE(ffn_layer_norm_bias);
-  PATTERN_DECL_NODE(ffn_layer_norm_mean);
-  PATTERN_DECL_NODE(ffn_layer_norm_variance);
-  PATTERN_DECL_NODE(ffn_layer_norm_out);
   PATTERN_DECL_NODE(ffn_matmul0);
   PATTERN_DECL_NODE(ffn_matmul0_w);
   PATTERN_DECL_NODE(ffn_matmul0_out);
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
@@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
   PATTERN_DECL_NODE(ffn_output);
+
+  PATTERN_DECL_NODE(ffn_layer_norm);
+  PATTERN_DECL_NODE(ffn_layer_norm_scale);
+  PATTERN_DECL_NODE(ffn_layer_norm_bias);
+  PATTERN_DECL_NODE(ffn_layer_norm_mean);
+  PATTERN_DECL_NODE(ffn_layer_norm_variance);
+  PATTERN_DECL_NODE(ffn_layer_norm_out);
 };
 
 struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
@@ -212,8 +216,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
@@ -226,6 +230,129 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_output);
 };
 
+struct MultiDevicesFusedMultiTransformerEncoderPattern : public PatternBase {
+  MultiDevicesFusedMultiTransformerEncoderPattern(PDPattern* pattern,
+                                                  const std::string& name_scope)
+      : PatternBase(pattern,
+                    name_scope,
+                    "multi_devices_fused_multi_transformer_encoder") {}
+
+  PDNode* operator()();
+
+  // Q, K, V path
+  PATTERN_DECL_NODE(input0);
+  PATTERN_DECL_NODE(c_identity0);
+  PATTERN_DECL_NODE(c_identity0_out);
+  PATTERN_DECL_NODE(c_identity1);
+  PATTERN_DECL_NODE(c_identity1_out);
+  PATTERN_DECL_NODE(c_identity2);
+  PATTERN_DECL_NODE(c_identity2_out);
+  PATTERN_DECL_NODE(matmul0);
+  PATTERN_DECL_NODE(matmul1);
+  PATTERN_DECL_NODE(matmul2);
+  PATTERN_DECL_NODE(matmul0_w);
+  PATTERN_DECL_NODE(matmul1_w);
+  PATTERN_DECL_NODE(matmul2_w);
+  PATTERN_DECL_NODE(matmul0_out);
+  PATTERN_DECL_NODE(matmul1_out);
+  PATTERN_DECL_NODE(matmul2_out);
+  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_out);
+  PATTERN_DECL_NODE(eltadd1_out);
+  PATTERN_DECL_NODE(eltadd2_out);
+  PATTERN_DECL_NODE(reshape2_0);
+  PATTERN_DECL_NODE(reshape2_1);
+  PATTERN_DECL_NODE(reshape2_2);
+  PATTERN_DECL_NODE(reshape2_0_out);
+  PATTERN_DECL_NODE(reshape2_1_out);
+  PATTERN_DECL_NODE(reshape2_2_out);
+  PATTERN_DECL_NODE(transpose2_0);
+  PATTERN_DECL_NODE(transpose2_1);
+  PATTERN_DECL_NODE(transpose2_2);
+  PATTERN_DECL_NODE(transpose2_0_out);
+  PATTERN_DECL_NODE(transpose2_1_out);
+  PATTERN_DECL_NODE(transpose2_2_out);
+  PATTERN_DECL_NODE(scale_q);
+  PATTERN_DECL_NODE(scale_q_out);
+
+  // Q, K matmul
+  PATTERN_DECL_NODE(matmul_qk);
+  PATTERN_DECL_NODE(matmul_qk_out);
+  PATTERN_DECL_NODE(eltadd_qk);
+  PATTERN_DECL_NODE(eltadd_qk_b);
+  PATTERN_DECL_NODE(eltadd_qk_out);
+  PATTERN_DECL_NODE(softmax_qk);
+  PATTERN_DECL_NODE(softmax_qk_out);
+
+  // QK, V matmul
+  PATTERN_DECL_NODE(matmul_qkv);
+  PATTERN_DECL_NODE(matmul_qkv_out);
+  PATTERN_DECL_NODE(reshape2_qkv);
+  PATTERN_DECL_NODE(reshape2_qkv_out);
+  PATTERN_DECL_NODE(transpose2_qkv);
+  PATTERN_DECL_NODE(transpose2_qkv_out);
+
+  // out linear
+  PATTERN_DECL_NODE(matmul_linear);
+  PATTERN_DECL_NODE(matmul_linear_w);
+  PATTERN_DECL_NODE(matmul_linear_out);
+  PATTERN_DECL_NODE(c_allreduce_sum);
+  PATTERN_DECL_NODE(c_allreduce_sum_out);
+  PATTERN_DECL_NODE(eltadd_linear);
+  PATTERN_DECL_NODE(eltadd_linear_b);
+  PATTERN_DECL_NODE(eltadd_linear_out);
+  PATTERN_DECL_NODE(dropout_linear);
+  PATTERN_DECL_NODE(dropout_linear_out);
+
+  // output elementwise_add
+  PATTERN_DECL_NODE(eltadd_out)
+  PATTERN_DECL_NODE(attention_output);
+
+  // post layer_norm
+  PATTERN_DECL_NODE(layer_norm);
+  PATTERN_DECL_NODE(layer_norm_scale);
+  PATTERN_DECL_NODE(layer_norm_bias);
+  PATTERN_DECL_NODE(layer_norm_mean);
+  PATTERN_DECL_NODE(layer_norm_variance);
+  PATTERN_DECL_NODE(layer_norm_out);
+
+  // Feed Forward nodes
+  PATTERN_DECL_NODE(ffn_c_identity);
+  PATTERN_DECL_NODE(ffn_c_identity_out);
+  PATTERN_DECL_NODE(ffn_matmul0);
+  PATTERN_DECL_NODE(ffn_matmul0_w);
+  PATTERN_DECL_NODE(ffn_matmul0_out);
+  PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(ffn_eltadd0_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
+  PATTERN_DECL_NODE(ffn_matmul1);
+  PATTERN_DECL_NODE(ffn_matmul1_w);
+  PATTERN_DECL_NODE(ffn_matmul1_out);
+  PATTERN_DECL_NODE(ffn_c_allreduce_sum);
+  PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
+  PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(ffn_eltadd1_out);
+
+  // output elementwise_add
+  PATTERN_DECL_NODE(ffn_eltadd_out)
+  PATTERN_DECL_NODE(ffn_output);
+
+  PATTERN_DECL_NODE(ffn_layer_norm);
+  PATTERN_DECL_NODE(ffn_layer_norm_scale);
+  PATTERN_DECL_NODE(ffn_layer_norm_bias);
+  PATTERN_DECL_NODE(ffn_layer_norm_mean);
+  PATTERN_DECL_NODE(ffn_layer_norm_variance);
+  PATTERN_DECL_NODE(ffn_layer_norm_out);
+};
+
 struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
     : public PatternBase {
   MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern(
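The node list above is easier to read as the layer it matches. A rough sketch of the data flow (one tensor-parallel transformer layer, post-layer_norm style; simplified from the declarations, since the exact edges live in the collapsed .cc diff):

// input0 -> c_identity{0,1,2} -> matmul{0,1,2} (Q/K/V) -> eltadd{0,1,2}
//        -> reshape2 -> transpose2 -> scale_q (Q only)
//        -> matmul_qk -> eltadd_qk -> softmax_qk -> matmul_qkv
//        -> transpose2_qkv -> reshape2_qkv
//        -> matmul_linear -> c_allreduce_sum -> eltadd_linear
//        -> eltadd_out -> attention_output -> layer_norm
//        -> ffn_c_identity -> ffn_matmul0 -> ffn_eltadd0 -> ffn_act
//        -> ffn_matmul1 -> ffn_c_allreduce_sum -> ffn_eltadd1
//        -> ffn_eltadd_out -> ffn_output -> ffn_layer_norm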
@@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
@@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
                   Scope* scope) const;
 };
 
+class MultiDevicesFusedMultiTransformerEncoderPass : public FusePassBase {
+ public:
+  MultiDevicesFusedMultiTransformerEncoderPass();
+  virtual ~MultiDevicesFusedMultiTransformerEncoderPass() {}
+
+ protected:
+  void ApplyImpl(Graph* graph) const;
+
+  const std::string name_scope_{
+      "multi_devices_fused_multi_transformer_encoder"};
+
+ private:
+  int BuildFusion(Graph* graph,
+                  const std::string& name_scope,
+                  Scope* scope) const;
+};
+
 class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
     : public FusePassBase {
  public:
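The new pass class mirrors the existing encoder passes; its .cc body sits in the collapsed diff above, but the usual FusePassBase shape of ApplyImpl in these passes looks roughly like this (a sketch, not the committed implementation):

void MultiDevicesFusedMultiTransformerEncoderPass::ApplyImpl(
    Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();  // weights are read from this scope
  int fusion_count = BuildFusion(graph, name_scope_, scope);
  AddStatis(fusion_count);      // record how many layers were fused
}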
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc
(Diff collapsed in this view.)
paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -179,6 +179,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
     "fused_multi_transformer_decoder_pass",
     "fused_multi_transformer_encoder_fuse_qkv_pass",
     "fused_multi_transformer_decoder_fuse_qkv_pass",
+    "multi_devices_fused_multi_transformer_encoder_pass",
     "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
     "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
     "fuse_multi_transformer_layer_pass",
@@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "fused_multi_transformer_decoder_pass",                         //
         "fused_multi_transformer_encoder_fuse_qkv_pass",                //
         "fused_multi_transformer_decoder_fuse_qkv_pass",                //
+        "multi_devices_fused_multi_transformer_encoder_pass",           //
         "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",  //
         "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",  //
         "fuse_multi_transformer_layer_pass",                            //
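With the pass registered in kGpuLowerPrecisionPasses and GpuPassStrategy it runs automatically for GPU inference; it can also be appended to a config's pass builder by name. A small usage sketch (the model path and GPU memory size are placeholders, and appending is redundant when the default GPU strategy already includes the pass):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./multi_transformer_model");  // placeholder path
  config.EnableUseGpu(1000 /* MB */, 0 /* device id */);
  // Shown for explicitness; GpuPassStrategy normally enables this already.
  config.pass_builder()->AppendPass(
      "multi_devices_fused_multi_transformer_encoder_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  // ... feed inputs and run predictor as usual.
  return 0;
}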