BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 29eec2dd (unverified)
Authored by lzy on Jan 04, 2023; committed via GitHub on Jan 04, 2023

add multi_devices_fused_multi_transformer_encoder_pass and cherry-pick from 48349 (#49383)

Parent: a2d7e1d7
Showing 6 changed files with 2776 additions and 1241 deletions (+2776 −1241)
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc         +50    −36
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h          +6     −6
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc         +2341  −1140
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h          +164   −20
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc  +213   −39
paddle/fluid/inference/api/paddle_pass_builder.cc                         +2     −0
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc (view file @ 29eec2dd)

@@ -31,6 +31,8 @@ namespace framework {
 namespace ir {
 namespace patterns {
 
+static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
+
 PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* input0 = pattern->NewNode(input0_repr());
   input0->assert_is_op_input("layer_norm", "X");
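Note: the new FFN_ACTS set is what lets a single pattern accept either relu or gelu as the FFN activation instead of hard-coding "gelu". A minimal standalone sketch of the membership test the assert_is_ops* checks boil down to (toy helper, not the actual PDNode API):

#include <string>
#include <unordered_set>

// Toy stand-in for assert_is_ops(FFN_ACTS): accept a node whose op type is
// any of the supported FFN activations, rather than exactly "gelu".
static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};

bool IsFfnActivation(const std::string& op_type) {
  return FFN_ACTS.count(op_type) > 0;
}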
@@ -359,13 +361,13 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
-                               ->AsIntermediate()
-                               ->assert_is_op_input("matmul_v2");
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2");
 
   auto* ffn_matmul1 =
       pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");

@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
@@ -678,13 +680,13 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
-                               ->AsIntermediate()
-                               ->assert_is_op_input("matmul_v2");
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2");
 
   auto* ffn_matmul1 =
       pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");

@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});

@@ -1026,13 +1028,13 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
-                               ->AsIntermediate()
-                               ->assert_is_op_input("matmul_v2");
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2");
 
   auto* ffn_matmul1 =
       pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");

@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
       .LinksTo({ffn_c_allreduce_sum_out_var});

@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                           Node* ffn_matmul1_w,
                           Node* ffn_eltadd0_b,
                           Node* ffn_eltadd1_b,
+                          Node* ffn_act,
                           Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();

@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("is_test", true);
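Note: recording the matched activation as the "act_method" attribute lets the fused_multi_transformer op dispatch on it at execution time. A schematic scalar version of that dispatch (hypothetical helper, not the actual fused kernel code):

#include <cmath>
#include <string>

// Hypothetical scalar dispatch implied by the "act_method" attribute.
float ApplyFfnAct(const std::string& act_method, float x) {
  if (act_method == "relu") return x > 0.0f ? x : 0.0f;
  // Otherwise gelu, in its erf formulation.
  return 0.5f * x * (1.0f + std::erf(x * 0.7071067811865475f));
}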
@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
   GET_IR_NODE_FROM_SUBGRAPH(
       ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
-      ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
+      ffn_act, ffn_act, fused_multi_transformer_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
-      ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
+      ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
       ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);

@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,

@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.

@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                           Node* ffn_matmul1_w,
                           Node* ffn_eltadd0_b,
                           Node* ffn_eltadd1_b,
+                          Node* ffn_act,
                           Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();

@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);

@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
       fused_multi_transformer_fuse_qkv_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
-      ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+      ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
-      ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+      ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
       ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);

@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,

@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.

@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                           Node* ffn_matmul1_w,
                           Node* ffn_eltadd0_b,
                           Node* ffn_eltadd1_b,
+                          Node* ffn_act,
                           Node* ffn_output) {
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_1_op = ffn_matmul1->Op();

@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);

@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
       fused_multi_transformer_fuse_qkv_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
-      ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+      ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
-      ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+      ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
       ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);

@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,

@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.
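Note: each GET_IR_NODE_FROM_SUBGRAPH(ffn_act, ffn_act, pattern) above binds a local Node* to whatever graph node the detector matched for that pattern slot; conceptually it is a checked map lookup, roughly like this toy (std::map stand-in, not the real macro):

#include <map>
#include <stdexcept>
#include <string>

struct Node {};  // stand-in for paddle::framework::ir::Node

// Toy version of the lookup: fetch the matched node for a named pattern
// slot, failing loudly if the subgraph has no binding for it.
Node* GetFromSubgraph(const std::map<std::string, Node*>& subgraph,
                      const std::string& slot) {
  auto it = subgraph.find(slot);
  if (it == subgraph.end())
    throw std::runtime_error("pattern slot not matched: " + slot);
  return it->second;
}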
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h (view file @ 29eec2dd)

@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);

@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);

@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
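Note: the header changes are purely mechanical renames of the pattern slots. PATTERN_DECL_NODE only generates a *_repr() accessor that serves as the node's registration key, so renaming ffn_gelu to ffn_act changes nothing but that key. Approximate expansion (paraphrased from graph_pattern_detector.h, not a verbatim copy):

// Roughly what PATTERN_DECL_NODE(ffn_act) expands to inside a PatternBase
// subclass, assuming the usual PDNodeName helper:
std::string ffn_act_repr() const {
  return PDNodeName(name_scope_, repr_, id_, "ffn_act");
}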
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc (view file @ 29eec2dd)

The source diff for this file is too large to display; view the blob instead.
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h (view file @ 29eec2dd)

@@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   // Q, K, V path
   PATTERN_DECL_NODE(input0);
-  PATTERN_DECL_NODE(layer_norm);
-  PATTERN_DECL_NODE(layer_norm_scale);
-  PATTERN_DECL_NODE(layer_norm_bias);
-  PATTERN_DECL_NODE(layer_norm_mean);
-  PATTERN_DECL_NODE(layer_norm_variance);
-  PATTERN_DECL_NODE(layer_norm_out);
   PATTERN_DECL_NODE(matmul0);
   PATTERN_DECL_NODE(matmul1);
   PATTERN_DECL_NODE(matmul2);

@@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   PATTERN_DECL_NODE(transpose2_0_out);
   PATTERN_DECL_NODE(transpose2_1_out);
   PATTERN_DECL_NODE(transpose2_2_out);
+  PATTERN_DECL_NODE(scale_q);
+  PATTERN_DECL_NODE(scale_q_out);
 
   // Q, K matmul
   PATTERN_DECL_NODE(matmul_qk);

@@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_linear);
   PATTERN_DECL_NODE(eltadd_linear_b);
   PATTERN_DECL_NODE(eltadd_linear_out);
   PATTERN_DECL_NODE(dropout_linear);
   PATTERN_DECL_NODE(dropout_linear_out);
 
   // output elementwise_add
   PATTERN_DECL_NODE(eltadd_out)
   PATTERN_DECL_NODE(attention_output);
 
-  // while loop
-  PATTERN_DECL_NODE(while0);
+  // post layer_norm
+  PATTERN_DECL_NODE(layer_norm);
+  PATTERN_DECL_NODE(layer_norm_scale);
+  PATTERN_DECL_NODE(layer_norm_bias);
+  PATTERN_DECL_NODE(layer_norm_mean);
+  PATTERN_DECL_NODE(layer_norm_variance);
+  PATTERN_DECL_NODE(layer_norm_out);
 
   // Feed Forward nodes
-  PATTERN_DECL_NODE(ffn_layer_norm);
-  PATTERN_DECL_NODE(ffn_layer_norm_scale);
-  PATTERN_DECL_NODE(ffn_layer_norm_bias);
-  PATTERN_DECL_NODE(ffn_layer_norm_mean);
-  PATTERN_DECL_NODE(ffn_layer_norm_variance);
-  PATTERN_DECL_NODE(ffn_layer_norm_out);
   PATTERN_DECL_NODE(ffn_matmul0);
   PATTERN_DECL_NODE(ffn_matmul0_w);
   PATTERN_DECL_NODE(ffn_matmul0_out);
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);

@@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
   PATTERN_DECL_NODE(ffn_output);
+
+  PATTERN_DECL_NODE(ffn_layer_norm);
+  PATTERN_DECL_NODE(ffn_layer_norm_scale);
+  PATTERN_DECL_NODE(ffn_layer_norm_bias);
+  PATTERN_DECL_NODE(ffn_layer_norm_mean);
+  PATTERN_DECL_NODE(ffn_layer_norm_variance);
+  PATTERN_DECL_NODE(ffn_layer_norm_out);
 };
 
 struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
@@ -212,11 +216,127 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
   PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd1_out);
 
   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
   PATTERN_DECL_NODE(ffn_output);
 };
 
+struct MultiDevicesFusedMultiTransformerEncoderPattern : public PatternBase {
+  MultiDevicesFusedMultiTransformerEncoderPattern(PDPattern* pattern,
+                                                  const std::string& name_scope)
+      : PatternBase(pattern,
+                    name_scope,
+                    "multi_devices_fused_multi_transformer_encoder") {}
+
+  PDNode* operator()();
+
+  // Q, K, V path
+  PATTERN_DECL_NODE(input0);
+  PATTERN_DECL_NODE(c_identity0);
+  PATTERN_DECL_NODE(c_identity0_out);
+  PATTERN_DECL_NODE(c_identity1);
+  PATTERN_DECL_NODE(c_identity1_out);
+  PATTERN_DECL_NODE(c_identity2);
+  PATTERN_DECL_NODE(c_identity2_out);
+  PATTERN_DECL_NODE(matmul0);
+  PATTERN_DECL_NODE(matmul1);
+  PATTERN_DECL_NODE(matmul2);
+  PATTERN_DECL_NODE(matmul0_w);
+  PATTERN_DECL_NODE(matmul1_w);
+  PATTERN_DECL_NODE(matmul2_w);
+  PATTERN_DECL_NODE(matmul0_out);
+  PATTERN_DECL_NODE(matmul1_out);
+  PATTERN_DECL_NODE(matmul2_out);
+  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_out);
+  PATTERN_DECL_NODE(eltadd1_out);
+  PATTERN_DECL_NODE(eltadd2_out);
+  PATTERN_DECL_NODE(reshape2_0);
+  PATTERN_DECL_NODE(reshape2_1);
+  PATTERN_DECL_NODE(reshape2_2);
+  PATTERN_DECL_NODE(reshape2_0_out);
+  PATTERN_DECL_NODE(reshape2_1_out);
+  PATTERN_DECL_NODE(reshape2_2_out);
+  PATTERN_DECL_NODE(transpose2_0);
+  PATTERN_DECL_NODE(transpose2_1);
+  PATTERN_DECL_NODE(transpose2_2);
+  PATTERN_DECL_NODE(transpose2_0_out);
+  PATTERN_DECL_NODE(transpose2_1_out);
+  PATTERN_DECL_NODE(transpose2_2_out);
+  PATTERN_DECL_NODE(scale_q);
+  PATTERN_DECL_NODE(scale_q_out);
+
+  // Q, K matmul
+  PATTERN_DECL_NODE(matmul_qk);
+  PATTERN_DECL_NODE(matmul_qk_out);
+  PATTERN_DECL_NODE(eltadd_qk);
+  PATTERN_DECL_NODE(eltadd_qk_b);
+  PATTERN_DECL_NODE(eltadd_qk_out);
+  PATTERN_DECL_NODE(softmax_qk);
+  PATTERN_DECL_NODE(softmax_qk_out);
+
+  // QK, V matmul
+  PATTERN_DECL_NODE(matmul_qkv);
+  PATTERN_DECL_NODE(matmul_qkv_out);
+  PATTERN_DECL_NODE(reshape2_qkv);
+  PATTERN_DECL_NODE(reshape2_qkv_out);
+  PATTERN_DECL_NODE(transpose2_qkv);
+  PATTERN_DECL_NODE(transpose2_qkv_out);
+
+  // out linear
+  PATTERN_DECL_NODE(matmul_linear);
+  PATTERN_DECL_NODE(matmul_linear_w);
+  PATTERN_DECL_NODE(matmul_linear_out);
+  PATTERN_DECL_NODE(c_allreduce_sum);
+  PATTERN_DECL_NODE(c_allreduce_sum_out);
+  PATTERN_DECL_NODE(eltadd_linear);
+  PATTERN_DECL_NODE(eltadd_linear_b);
+  PATTERN_DECL_NODE(eltadd_linear_out);
+  PATTERN_DECL_NODE(dropout_linear);
+  PATTERN_DECL_NODE(dropout_linear_out);
+
+  // output elementwise_add
+  PATTERN_DECL_NODE(eltadd_out)
+  PATTERN_DECL_NODE(attention_output);
+
+  // post layer_norm
+  PATTERN_DECL_NODE(layer_norm);
+  PATTERN_DECL_NODE(layer_norm_scale);
+  PATTERN_DECL_NODE(layer_norm_bias);
+  PATTERN_DECL_NODE(layer_norm_mean);
+  PATTERN_DECL_NODE(layer_norm_variance);
+  PATTERN_DECL_NODE(layer_norm_out);
+
+  // Feed Forward nodes
+  PATTERN_DECL_NODE(ffn_c_identity);
+  PATTERN_DECL_NODE(ffn_c_identity_out);
+  PATTERN_DECL_NODE(ffn_matmul0);
+  PATTERN_DECL_NODE(ffn_matmul0_w);
+  PATTERN_DECL_NODE(ffn_matmul0_out);
+  PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(ffn_eltadd0_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
+  PATTERN_DECL_NODE(ffn_matmul1);
+  PATTERN_DECL_NODE(ffn_matmul1_w);
+  PATTERN_DECL_NODE(ffn_matmul1_out);
+  PATTERN_DECL_NODE(ffn_c_allreduce_sum);
+  PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
+  PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(ffn_eltadd1_out);
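Note: the multi-devices pattern brackets each parallel GEMM with c_identity on the way in and c_allreduce_sum on the way out, the usual tensor-parallel layout in which each rank holds a weight shard, produces a partial result, and the allreduce sums the partials. A single-process toy of that reduction (stand-in functions, not Paddle's collective ops):

#include <cstddef>
#include <vector>

// c_identity is a no-op on the data path; it only marks the parallel region.
std::vector<float> CIdentity(const std::vector<float>& x) { return x; }

// c_allreduce_sum: each rank contributes a partial output; the final
// result is their elementwise sum.
std::vector<float> CAllReduceSum(
    const std::vector<std::vector<float>>& partials) {
  std::vector<float> out(partials.front().size(), 0.0f);
  for (const auto& part : partials)
    for (std::size_t i = 0; i < part.size(); ++i) out[i] += part[i];
  return out;
}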
@@ -224,6 +344,13 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
   // output elementwise_add
   PATTERN_DECL_NODE(ffn_eltadd_out)
   PATTERN_DECL_NODE(ffn_output);
+
+  PATTERN_DECL_NODE(ffn_layer_norm);
+  PATTERN_DECL_NODE(ffn_layer_norm_scale);
+  PATTERN_DECL_NODE(ffn_layer_norm_bias);
+  PATTERN_DECL_NODE(ffn_layer_norm_mean);
+  PATTERN_DECL_NODE(ffn_layer_norm_variance);
+  PATTERN_DECL_NODE(ffn_layer_norm_out);
 };
 
 struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern

@@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);

@@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
                   Scope* scope) const;
 };
 
+class MultiDevicesFusedMultiTransformerEncoderPass : public FusePassBase {
+ public:
+  MultiDevicesFusedMultiTransformerEncoderPass();
+  virtual ~MultiDevicesFusedMultiTransformerEncoderPass() {}
+
+ protected:
+  void ApplyImpl(Graph* graph) const;
+
+  const std::string name_scope_{
+      "multi_devices_fused_multi_transformer_encoder"};
+
+ private:
+  int BuildFusion(Graph* graph,
+                  const std::string& name_scope,
+                  Scope* scope) const;
+};
+
 class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
     : public FusePassBase {
  public:
...
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc (view file @ 29eec2dd)

@@ -56,8 +56,8 @@ Scope* CreateParamScope() {
   // FFN: fc1 -> (gelu) -> fc2
   AddVarToScope(param_scope, "ffn_weights0", {1024, 4096});
   AddVarToScope(param_scope, "ffn_weights1", {4096, 1024});
-  AddVarToScope(param_scope, "ffn_bias_0", {4096});
-  AddVarToScope(param_scope, "ffn_bias_1", {1024});
+  AddVarToScope(param_scope, "ffn_bias0", {4096});
+  AddVarToScope(param_scope, "ffn_bias1", {1024});
 
   return param_scope;
 }
@@ -65,10 +65,9 @@ Scope* CreateParamScope() {
 TEST(FusedMultiTransformerEncoderPass, basic) {
   // inputs                           operator              output
   // --------------------------------------------------------------------
-  // (x, ln_scale, ln_bias)           layer_norm         -> layer_norm_out
-  // (layer_norm_out, weights_0)      matmul_v2          -> matmul_out0
-  // (layer_norm_out, weights_1)      matmul_v2          -> matmul_out1
-  // (layer_norm_out, weights_2)      matmul_v2          -> matmul_out2
+  // (x, weights_0)                   matmul_v2          -> matmul_out0
+  // (x, weights_1)                   matmul_v2          -> matmul_out1
+  // (x, weights_2)                   matmul_v2          -> matmul_out2
   // (matmul_out0, bias_0)            elementwise_add    -> eltadd_0
   // (matmul_out1, bias_1)            elementwise_add    -> eltadd_1
   // (matmul_out2, bias_2)            elementwise_add    -> eltadd_2

@@ -78,7 +77,8 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   // (reshape_0)                      transpose2         -> transpose_0
   // (reshape_1)                      transpose2         -> transpose_1
   // (reshape_2)                      transpose2         -> transpose_2
-  // (transpose_0, transpose_1)       matmul             -> matmul_qk
+  // (transpose_0)                    scale              -> scale_0
+  // (scale_0, transpose_1)           matmul             -> matmul_qk
   // (matmul_qk, bias_qk)             elementwise_add    -> eltadd_qk
   // (eltadd_qk)                      softmax            -> softmax_qk
   // (softmax_qk, transpose_2)        matmul_v2          -> matmul_qkv

@@ -86,35 +86,28 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   // (transpose_qkv)                  reshape            -> reshape_qkv
   // (reshape_qkv)                    matmul_v2          -> matmul_linear
   // (matmul_linear)                  elementwise_add    -> eltadd_linear
-  // (eltadd_out)                     elementwise_add    -> attention_out
+  // (eltadd_linear)                  elementwise_add    -> attention_out
   //
-  // (attention_out, scale, bias)     layer_norm         -> ffn_layer_norm_out
+  // (attention_out, scale, bias)     layer_norm         -> layer_norm_out
   // (layer_norm_out, ffn_matmul0_w)  matmul_v2          -> ffn_matmul0
   // (ffn_matmul0, ffn_bias0)         elementwise_add    -> ffn_eltadd0
   // (ffn_eltadd0)                    gelu               -> ffn_gelu
   // (ffn_gelu)                       matmul_v2          -> ffn_matmul1
   // (ffn_matmul1, ffn_bias1)         elementwise_add    -> ffn_eltadd1
-  // (attention_out, ffn_eltadd1)     elementwise_add    -> ffn_output
-  //
-  // (transpose_1, transpose_2)       while              -> decoder block
+  // (layer_norm_out, ffn_eltadd1)    elementwise_add    -> ffn_output
+  // (ffn_output, scale, bias)        layer_norm         -> ffn_layer_norm_out
 
   Layers layers;
-  // MHA: pre LayerNorm
   auto* x = layers.data("x", {1, 128, 1024});
-  auto* ln_scale = layers.data("ln_scale", {1024}, true);
-  auto* ln_bias = layers.data("ln_bias", {1024}, true);
-  auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
 
   // MHA: QKV fc
   auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
   auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
   auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
-  auto* matmul_out_0 =
-      layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
-  auto* matmul_out_1 =
-      layers.matmul_v2(ln_out, weights_1, nullptr, false, true);
-  auto* matmul_out_2 =
-      layers.matmul_v2(ln_out, weights_2, nullptr, false, true);
+  auto* matmul_out_0 = layers.matmul_v2(x, weights_0, nullptr, false, false);
+  auto* matmul_out_1 = layers.matmul_v2(x, weights_1, nullptr, false, false);
+  auto* matmul_out_2 = layers.matmul_v2(x, weights_2, nullptr, false, false);
 
   auto* b0 = layers.data("bias_0", {1024}, true);
   auto* b1 = layers.data("bias_1", {1024}, true);

@@ -136,14 +129,13 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
   auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
 
-  // Link to decoder while block
-  layers.while_loop({transpose_1, transpose_2});
+  // q scale
+  auto* scale_q = layers.scale(transpose_0, 0.125, 0, false);
 
   // MHA: QK matmul
   auto* matmul_qk =
-      layers.matmul(transpose_0, transpose_1, nullptr, false, true);
+      layers.matmul_v2(scale_q, transpose_1, nullptr, false, true);
 
-  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
+  auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
@@ -155,19 +147,18 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   // MHA: out Linear
   auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
-  auto* bias_l = layers.data("weightsl", {1024, 1024}, true);
+  auto* bias_l = layers.data("bias_l", {1024}, true);
   auto* linear_matmut_out =
-      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
+      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false);
   auto* linear_eltadd_out =
       layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
   auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
 
-  // FFN: pre LayerNorm
-  auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
-  auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
-  auto* ffn_ln_out =
-      layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
+  // post LayerNorm
+  auto* ln_scale = layers.data("ln_scale", {1024}, true);
+  auto* ln_bias = layers.data("ln_bias", {1024}, true);
+  auto* ln_out = layers.layer_norm(attention_out, ln_scale, ln_bias)[0];
 
   // FFN: fc1 -> gelu -> fc2
   auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);

@@ -175,7 +166,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
   auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
   auto* ffn_matmul0_out =
-      layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
+      layers.matmul_v2(ln_out, ffn_weights0, nullptr, false, true);
   auto* ffn_eltadd0_out =
       layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
   auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);

@@ -184,7 +175,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   auto* ffn_eltadd1_out =
       layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
 
-  layers.elementwise_add(attention_out, ffn_eltadd1_out);
+  auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out);
+
+  // FFN: post LayerNorm
+  auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
+  auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
+  layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0];
 
   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   graph->Set("__param_scope__", CreateParamScope());

@@ -203,12 +199,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
   int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
 
   PADDLE_ENFORCE_EQ(num_nodes_before,
-                    num_nodes_after + 56,
+                    num_nodes_after + 58,
                     platform::errors::InvalidArgument(
                         "After the fused_multi_transformer_encoder_pass, The "
                         "node num in graph "
                         "should be %d, but the result is %d",
-                        num_nodes_before - 56,
+                        num_nodes_before - 58,
                         num_nodes_after));
   PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                     1,
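Note: the bookkeeping here counts total graph nodes before and after the pass and the number of fused ops left behind. GetNumOpNodes (from the pass tester helpers) behaves roughly like this sketch (assumed shape, written against Paddle's ir::Graph API):

#include <string>

#include "paddle/fluid/framework/ir/graph.h"

// Roughly what GetNumOpNodes(graph, op_type) computes: the number of
// operator nodes in the graph whose type matches op_type.
int CountOpNodes(const paddle::framework::ir::Graph& graph,
                 const std::string& op_type) {
  int count = 0;
  for (auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op() && node->Op()->Type() == op_type) ++count;
  }
  return count;
}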
@@ -225,6 +221,183 @@ TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) {
           .IsPassCompatible("fused_multi_transformer_encoder_pass"));
 }
 
+TEST(MultiDevicesFusedMultiTransformerEncoderPass, basic) {
+  // inputs                              operator              output
+  // --------------------------------------------------------------------
+  // (x)                                 c_identity         -> c_identity0_out
+  // (x)                                 c_identity         -> c_identity1_out
+  // (x)                                 c_identity         -> c_identity2_out
+  // (c_identity0_out, weights_0)        matmul_v2          -> matmul_out0
+  // (c_identity1_out, weights_1)        matmul_v2          -> matmul_out1
+  // (c_identity2_out, weights_2)        matmul_v2          -> matmul_out2
+  // (matmul_out0, bias_0)               elementwise_add    -> eltadd_0
+  // (matmul_out1, bias_1)               elementwise_add    -> eltadd_1
+  // (matmul_out2, bias_2)               elementwise_add    -> eltadd_2
+  // (eltadd_0)                          reshape2           -> reshape_0
+  // (eltadd_1)                          reshape2           -> reshape_1
+  // (eltadd_2)                          reshape2           -> reshape_2
+  // (reshape_0)                         transpose2         -> transpose_0
+  // (reshape_1)                         transpose2         -> transpose_1
+  // (reshape_2)                         transpose2         -> transpose_2
+  // (transpose_0)                       scale              -> scale_0
+  // (scale_0, transpose_1)              matmul             -> matmul_qk
+  // (matmul_qk, bias_qk)                elementwise_add    -> eltadd_qk
+  // (eltadd_qk)                         softmax            -> softmax_qk
+  // (softmax_qk, transpose_2)           matmul_v2          -> matmul_qkv
+  // (matmul_qkv)                        transpose          -> transpose_qkv
+  // (transpose_qkv)                     reshape            -> reshape_qkv
+  // (reshape_qkv)                       matmul_v2          -> matmul_linear
+  // (matmul_linear)                     c_all_reduce       -> c_all_reduce_out
+  // (c_all_reduce_out)                  elementwise_add    -> eltadd_linear
+  // (eltadd_linear)                     elementwise_add    -> attention_out
+  //
+  // (attention_out, scale, bias)        layer_norm         -> layer_norm_out
+  // (layer_norm_out)                    c_identity         -> ffn_c_identity_out
+  // (ffn_c_identity_out, ffn_matmul0_w) matmul_v2          -> ffn_matmul0
+  // (ffn_matmul0, ffn_bias0)            elementwise_add    -> ffn_eltadd0
+  // (ffn_eltadd0)                       gelu               -> ffn_gelu
+  // (ffn_gelu)                          matmul_v2          -> ffn_matmul1
+  // (ffn_matmul1)                       c_all_reduce       -> ffn_c_all_reduce_out
+  // (ffn_c_all_reduce_out, ffn_bias1)   elementwise_add    -> ffn_eltadd1
+  // (layer_norm_out, ffn_eltadd1)       elementwise_add    -> ffn_output
+  // (ffn_output, scale, bias)           layer_norm         -> ffn_layer_norm_out
+
+  Layers layers;
+  auto* x = layers.data("x", {1, 128, 1024});
+  auto* c_identity0_out = layers.c_identity(x);
+  auto* c_identity1_out = layers.c_identity(x);
+  auto* c_identity2_out = layers.c_identity(x);
+
+  // MHA: QKV fc
+  auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
+  auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
+  auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
+  auto* matmul_out_0 =
+      layers.matmul_v2(c_identity0_out, weights_0, nullptr, false, false);
+  auto* matmul_out_1 =
+      layers.matmul_v2(c_identity1_out, weights_1, nullptr, false, false);
+  auto* matmul_out_2 =
+      layers.matmul_v2(c_identity2_out, weights_2, nullptr, false, false);
+
+  auto* b0 = layers.data("bias_0", {1024}, true);
+  auto* b1 = layers.data("bias_1", {1024}, true);
+  auto* b2 = layers.data("bias_2", {1024}, true);
+  auto* elementwise_out_0 =
+      layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
+  auto* elementwise_out_1 =
+      layers.elementwise_add(matmul_out_1, b1, nullptr, 2);
+  auto* elementwise_out_2 =
+      layers.elementwise_add(matmul_out_2, b2, nullptr, 2);
+
+  std::vector<int> shape = {1, 128, 16, 64};
+  auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
+  auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
+  auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
+
+  std::vector<int> axis = {0, 2, 1, 3};
+  auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
+  auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
+  auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
+
+  // q scale
+  auto* scale_q = layers.scale(transpose_0, 0.125, 0, false);
+
+  // MHA: QK matmul
+  auto* matmul_qk =
+      layers.matmul_v2(scale_q, transpose_1, nullptr, false, true);
+
+  auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
+  auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
+  auto* softmax_qk = layers.softmax(elementwise_qk, -1);
+
+  // MHA: QKV matmul
+  auto* matmul_qkv = layers.matmul_v2(softmax_qk, transpose_2);
+  auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
+  auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
+
+  // MHA: out Linear
+  auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
+  auto* bias_l = layers.data("bias_l", {1024}, true);
+  auto* linear_matmut_out =
+      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false);
+  auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmut_out);
+  auto* linear_eltadd_out =
+      layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
+  auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
+
+  // post LayerNorm
+  auto* ln_scale = layers.data("ln_scale", {1024}, true);
+  auto* ln_bias = layers.data("ln_bias", {1024}, true);
+  auto* ln_out = layers.layer_norm(attention_out, ln_scale, ln_bias)[0];
+  auto* ffn_c_identity_out = layers.c_identity(ln_out);
+
+  // FFN: fc1 -> gelu -> fc2
+  auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
+  auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
+  auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
+  auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
+  auto* ffn_matmul0_out = layers.matmul_v2(
+      ffn_c_identity_out, ffn_weights0, nullptr, false, false);
+  auto* ffn_eltadd0_out =
+      layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
+  auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
+  auto* ffn_matmul1_out =
+      layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, false);
+  auto* ffn_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out);
+  auto* ffn_eltadd1_out =
+      layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2);
+
+  auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out);
+
+  // FFN: post LayerNorm
+  auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
+  auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
+  layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0];
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  graph->Set("__param_scope__", CreateParamScope());
+  graph->Set("enable_int8", new bool(false));
+
+  auto pass = PassRegistry::Instance().Get(
+      "multi_devices_fused_multi_transformer_encoder_pass");
+  if (pass.get() == nullptr)
+    LOG(INFO)
+        << "get multi_devices_fused_multi_transformer_encoder_pass failed";
+
+  int num_nodes_before = graph->Nodes().size();
+  VLOG(3) << DebugString(graph);
+
+  graph.reset(pass->Apply(graph.release()));
+  int num_nodes_after = graph->Nodes().size();
+  VLOG(3) << DebugString(graph);
+  int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
+
+  PADDLE_ENFORCE_EQ(num_nodes_before,
+                    num_nodes_after + 70,
+                    platform::errors::InvalidArgument(
+                        "After the fused_multi_transformer_encoder_pass, The "
+                        "node num in graph "
+                        "should be %d, but the result is %d",
+                        num_nodes_before - 70,
+                        num_nodes_after));
+  PADDLE_ENFORCE_EQ(num_fused_nodes_after,
+                    1,
+                    platform::errors::InvalidArgument(
+                        "After the fused_multi_transformer_encoder pass, "
+                        "there should be one fused_multi_transformer op, "
+                        "but the result is %d",
+                        num_fused_nodes_after));
+}
+
+TEST(MultiDevicesFusedMultiTransformerEncoderPass, pass_op_version_check) {
+  ASSERT_TRUE(
+      paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
+          .IsPassCompatible(
+              "multi_devices_fused_multi_transformer_encoder_pass"));
+}
+
 TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
   // inputs                           operator              output
   // --------------------------------------------------------------------
@@ -292,7 +465,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
   auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
 
-  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
+  auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);

@@ -447,7 +620,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
   auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
   auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
 
-  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
+  auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);

@@ -542,4 +715,5 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
 
 USE_PASS(fused_multi_transformer_encoder_pass);
 USE_PASS(fused_multi_transformer_encoder_fuse_qkv_pass);
+USE_PASS(multi_devices_fused_multi_transformer_encoder_pass);
 USE_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass);
paddle/fluid/inference/api/paddle_pass_builder.cc (view file @ 29eec2dd)

@@ -179,6 +179,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
     "fused_multi_transformer_decoder_pass",
     "fused_multi_transformer_encoder_fuse_qkv_pass",
     "fused_multi_transformer_decoder_fuse_qkv_pass",
+    "multi_devices_fused_multi_transformer_encoder_pass",
     "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
     "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
     "fuse_multi_transformer_layer_pass",

@@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "fused_multi_transformer_decoder_pass",                         //
         "fused_multi_transformer_encoder_fuse_qkv_pass",                //
         "fused_multi_transformer_decoder_fuse_qkv_pass",                //
+        "multi_devices_fused_multi_transformer_encoder_pass",           //
         "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",  //
         "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",  //
         "fuse_multi_transformer_layer_pass",                            //
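Note: with the new pass listed in both kGpuLowerPrecisionPasses and the default GpuPassStrategy, GPU inference picks it up automatically. A minimal usage sketch of the C++ inference API for reference (model path and memory figures are placeholders):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder model path
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  // The fused multi-transformer passes run as part of the default GPU pass
  // strategy; a pass could also be appended by hand if needed, e.g.:
  // config.pass_builder()->AppendPass(
  //     "multi_devices_fused_multi_transformer_encoder_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}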