Unverified commit 2b848aef
Authored on Feb 01, 2023 by Yuang Liu; committed via GitHub on Feb 01, 2023
Fused attention pass fwd, create the fused_attention op. (#50125)
Parent: e6d29e00
Showing 3 changed files with 231 additions and 53 deletions (+231, -53)
paddle/fluid/framework/ir/fused_attention_pass.cc  (+214, -32)
paddle/fluid/framework/ir/fused_attention_pass.h  (+11, -13)
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py  (+6, -8)
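Editorial note (not part of the commit): the new forward handler in fused_attention_pass.cc rewrites each matched multi-head-attention subgraph into a single fused_attention op. For orientation, the inputs, outputs, and attributes it wires up, collected from the handler shown below, are summarized here as a small Python sketch; the grouping is editorial and is not a complete schema of the op.

# Slots the new handler sets on the fused_attention OpDesc (names taken from
# the diff below); editorial summary only, not an exhaustive op definition.
fused_attention_interface = {
    "inputs": ["X", "LnScale", "LnBias", "QKVW", "QKVBias",
               "SrcMask", "OutLinearW", "OutLinearBias"],
    "outputs": ["LnOut", "LnMean", "LnVariance", "QKVOut", "QKVBiasOut",
                "TransposeOut2", "QKOut", "SrcMaskOut", "SoftmaxOut",
                "AttnDropoutMaskOut", "AttnDropoutOut", "QKTVOut", "FMHAOut",
                "OutLinearOut", "DropoutMaskOut", "Y"],
    "attrs": ["pre_layer_norm", "epsilon", "transpose_qkv_wb", "num_heads",
              "attn_dropout_rate", "is_test", "attn_dropout_fix_seed",
              "attn_dropout_seed", "attn_dropout_implementation",
              "dropout_rate", "dropout_fix_seed", "dropout_seed",
              "dropout_implementation", "add_residual"],
}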
paddle/fluid/framework/ir/fused_attention_pass.cc
@@ -22,7 +22,6 @@ namespace patterns {
 
 PDNode* FusedAttentionPattern::operator()(PDNode* x,
                                           bool pre_layer_norm,
-                                          bool post_layer_norm,
                                           bool has_attn_mask,
                                           bool do_dropout,
                                           bool add_residual) {
@@ -259,7 +258,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
   out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node})
       .LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node});
 
-  if (!add_residual && !post_layer_norm) {
+  if (!add_residual && pre_layer_norm) {
     return out_linear_dropout_out_node;
   }
@@ -276,7 +275,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
     residual_ele_add_node->LinksFrom({x, out_linear_dropout_out_node})
         .LinksTo({residual_ele_add_out_node});
 
-    if (!post_layer_norm) {
+    if (pre_layer_norm) {
       return residual_ele_add_out_node;
     }
   }
@@ -323,13 +322,12 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
 
 PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
                                               bool pre_layer_norm,
-                                              bool post_layer_norm,
                                               bool has_attn_mask,
                                               bool do_dropout,
                                               bool add_residual) {
   // post layer norm
   PDNode* post_layer_norm_grad_out_node{nullptr};
-  if (post_layer_norm) {
+  if (!pre_layer_norm) {
     auto* post_layer_norm_grad_node =
         pattern->NewNode(post_layer_norm_grad_op_repr())
             ->assert_is_op("layer_norm_grad");
@@ -375,7 +373,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   PDNode* residual_ele_add_grad_x_grad_node{nullptr};
   if (add_residual) {
     PDNode* ele_add_grad_input = x;
-    if (post_layer_norm) {
+    if (!pre_layer_norm) {
       ele_add_grad_input = post_layer_norm_grad_out_node;
     }
     auto* residual_ele_add_grad_node =
@@ -404,7 +402,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   // get the real input x for dropout grad
   PDNode* out_linear_grad_input_node = x;
-  if (post_layer_norm && !add_residual) {
+  if (!pre_layer_norm && !add_residual) {
     out_linear_grad_input_node = post_layer_norm_grad_out_node;
   } else if (add_residual) {
     out_linear_grad_input_node = residual_ele_add_grad_out_node;
@@ -769,11 +767,11 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
 void FusedAttentionsPass::ApplyImpl(Graph* graph) const {
   FusePassBase::Init(name_scope_, graph);
 
-  graph = PreMaskDropResPostFwd(graph);
-  graph = PreMaskDropResPostBwd(graph);
+  graph = PreMaskDropResFwd(graph);
+  graph = PreMaskDropResBwd(graph);
 }
 
-ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
+ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
                 ->NewNode(patterns::PDNodeName(name_scope_, "x"))
@@ -784,7 +782,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
   fused_attention_pattern(x,
                           /* pre_layer_norm */ true,
-                          /* post_layer_norm */ true,
                           /* has_attn_mask */ true,
                           /* do_dropout */ true,
                           /* add_residual */ true);
@@ -835,10 +832,191 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
                               fused_attention_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        post_layer_norm_op_node, post_layer_norm_op, fused_attention_pattern);
-
-    // TODO(Yuang Liu): finish the handler
+
+    OpDesc fused_attention_op_desc(pre_layer_norm_op_node->Op()->Block());
+    fused_attention_op_desc.SetType("fused_attention");
+    fused_attention_op_desc.SetInput("X", {subgraph.at(x)->Name()});
+    fused_attention_op_desc.SetAttr("pre_layer_norm", true);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_scale_node, pre_layer_norm_scale, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_bias_node, pre_layer_norm_bias, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_out_node, pre_layer_norm_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_mean_node, pre_layer_norm_mean, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_variance_node, pre_layer_norm_variance, fused_attention_pattern);
+    fused_attention_op_desc.SetInput("LnScale", {pre_layer_norm_scale_node->Name()});
+    fused_attention_op_desc.SetInput("LnBias", {pre_layer_norm_bias_node->Name()});
+    fused_attention_op_desc.SetOutput("LnOut", {pre_layer_norm_out_node->Name()});
+    fused_attention_op_desc.SetOutput("LnMean", {pre_layer_norm_mean_node->Name()});
+    fused_attention_op_desc.SetOutput("LnVariance", {pre_layer_norm_variance_node->Name()});
+    fused_attention_op_desc.SetAttr(
+        "epsilon",
+        PADDLE_GET_CONST(float, pre_layer_norm_op_node->Op()->GetAttr("epsilon")));
+
+    fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
+    std::vector<int> shape = PADDLE_GET_CONST(
+        std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
+    fused_attention_op_desc.SetAttr("num_heads", shape[2]);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        fuse_qkv_ele_add_bias_node, fuse_qkv_ele_add_bias, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        fuse_qkv_ele_add_out_node, fuse_qkv_ele_add_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        fuse_qkv_transpose_out_node, fuse_qkv_transpose_out, fused_attention_pattern);
+    fused_attention_op_desc.SetInput("QKVW", {fuse_qkv_matmul_w_node->Name()});
+    fused_attention_op_desc.SetInput("QKVBias", {fuse_qkv_ele_add_bias_node->Name()});
+    fused_attention_op_desc.SetOutput("QKVOut", {fuse_qkv_matmul_out_node->Name()});
+    fused_attention_op_desc.SetOutput("QKVBiasOut", {fuse_qkv_ele_add_out_node->Name()});
+    fused_attention_op_desc.SetOutput("TransposeOut2", {fuse_qkv_transpose_out_node->Name()});
+
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qk_matmul_out_node, qk_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        add_mask_ele_add_mask_node, add_mask_ele_add_mask, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        add_mask_ele_add_out_node, add_mask_ele_add_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qk_softmax_out_node, qk_softmax_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        attn_dropout_out_node, attn_dropout_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        attn_dropout_mask_node, attn_dropout_mask, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qkv_matmul_out_node, qkv_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qkv_reshape_out_node, qkv_reshape_out, fused_attention_pattern);
+    fused_attention_op_desc.SetOutput("QKOut", {qk_matmul_out_node->Name()});
+    fused_attention_op_desc.SetInput("SrcMask", {add_mask_ele_add_mask_node->Name()});
+    fused_attention_op_desc.SetOutput("SrcMaskOut", {add_mask_ele_add_out_node->Name()});
+    fused_attention_op_desc.SetOutput("SoftmaxOut", {qk_softmax_out_node->Name()});
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_rate",
+        PADDLE_GET_CONST(float, attn_dropout_op_node->Op()->GetAttr("dropout_prob")));
+    fused_attention_op_desc.SetAttr(
+        "is_test",
+        PADDLE_GET_CONST(bool, attn_dropout_op_node->Op()->GetAttr("is_test")));
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_fix_seed",
+        PADDLE_GET_CONST(bool, attn_dropout_op_node->Op()->GetAttr("fix_seed")));
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_seed",
+        PADDLE_GET_CONST(int, attn_dropout_op_node->Op()->GetAttr("seed")));
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_implementation",
+        PADDLE_GET_CONST(std::string,
+                         attn_dropout_op_node->Op()->GetAttr("dropout_implementation")));
+    fused_attention_op_desc.SetOutput("AttnDropoutMaskOut", {attn_dropout_mask_node->Name()});
+    fused_attention_op_desc.SetOutput("AttnDropoutOut", {attn_dropout_out_node->Name()});
+    fused_attention_op_desc.SetOutput("QKTVOut", {qkv_matmul_out_node->Name()});
+    fused_attention_op_desc.SetOutput("FMHAOut", {qkv_reshape_out_node->Name()});
+
+    GET_IR_NODE_FROM_SUBGRAPH(
+        out_linear_matmul_w_node, out_linear_matmul_w, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        out_linear_matmul_out_node, out_linear_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        out_linear_ele_add_bias_node, out_linear_ele_add_bias, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        out_linear_ele_add_out_node, out_linear_ele_add_out, fused_attention_pattern);
+    fused_attention_op_desc.SetInput("OutLinearW", {out_linear_matmul_w_node->Name()});
+    fused_attention_op_desc.SetInput("OutLinearBias", {out_linear_ele_add_bias_node->Name()});
+    fused_attention_op_desc.SetOutput("OutLinearOut", {out_linear_matmul_out_node->Name()});
+    GET_IR_NODE_FROM_SUBGRAPH(
+        out_linear_dropout_mask_node, out_linear_dropout_mask, fused_attention_pattern);
+    fused_attention_op_desc.SetAttr(
+        "dropout_rate",
+        PADDLE_GET_CONST(float, out_linear_dropout_op_node->Op()->GetAttr("dropout_prob")));
+    fused_attention_op_desc.SetAttr(
+        "dropout_fix_seed",
+        PADDLE_GET_CONST(bool, out_linear_dropout_op_node->Op()->GetAttr("fix_seed")));
+    fused_attention_op_desc.SetAttr(
+        "dropout_seed",
+        PADDLE_GET_CONST(int, out_linear_dropout_op_node->Op()->GetAttr("seed")));
+    fused_attention_op_desc.SetAttr(
+        "dropout_implementation",
+        PADDLE_GET_CONST(std::string,
+                         out_linear_dropout_op_node->Op()->GetAttr("dropout_implementation")));
+    fused_attention_op_desc.SetOutput("DropoutMaskOut", {out_linear_dropout_mask_node->Name()});
+
+    GET_IR_NODE_FROM_SUBGRAPH(
+        residual_ele_add_out_node, residual_ele_add_out, fused_attention_pattern);
+    fused_attention_op_desc.SetAttr("add_residual", true);
+    fused_attention_op_desc.SetOutput("Y", {residual_ele_add_out_node->Name()});
+
+    auto fused_attention_node = g->CreateOpNode(&fused_attention_op_desc);
+
+    IR_NODE_LINK_TO(subgraph.at(x), fused_attention_node);
+    IR_NODE_LINK_TO(pre_layer_norm_scale_node, fused_attention_node);
+    IR_NODE_LINK_TO(pre_layer_norm_bias_node, fused_attention_node);
+    IR_NODE_LINK_TO(fuse_qkv_matmul_w_node, fused_attention_node);
+    IR_NODE_LINK_TO(fuse_qkv_ele_add_bias_node, fused_attention_node);
+    IR_NODE_LINK_TO(add_mask_ele_add_mask_node, fused_attention_node);
+    IR_NODE_LINK_TO(out_linear_matmul_w_node, fused_attention_node);
+    IR_NODE_LINK_TO(out_linear_ele_add_bias_node, fused_attention_node);
+    IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_mean_node);
+    IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_variance_node);
+    IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_ele_add_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_transpose_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qk_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, add_mask_ele_add_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qk_softmax_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, attn_dropout_mask_node);
+    IR_NODE_LINK_TO(fused_attention_node, attn_dropout_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qkv_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qkv_reshape_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, out_linear_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, out_linear_dropout_mask_node);
+    IR_NODE_LINK_TO(fused_attention_node, residual_ele_add_out_node);
 
     GraphSafeRemoveNodes(g,
                          {pre_layer_norm_op_node,
@@ -858,8 +1036,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
                           out_linear_matmul_op_node,
                           out_linear_ele_add_op_node,
                           out_linear_dropout_op_node,
-                          residual_ele_add_op_node,
-                          post_layer_norm_op_node});
+                          residual_ele_add_op_node});
 
     found_fused_attention++;
   };
@@ -869,18 +1046,17 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
   return graph;
 }
 
-ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
+ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
                 ->NewNode(patterns::PDNodeName(name_scope_, "x"))
                 ->AsInput()
-                ->assert_is_op_input("layer_norm_grad", "Y@GRAD");
+                ->assert_is_op_input("elementwise_add_grad", "Out@GRAD");
   patterns::FusedAttentionGradPattern fused_attention_grad_pattern(
       gpd.mutable_pattern(), "fused_attention_grad_pattern");
   fused_attention_grad_pattern(x,
                                /* pre_layer_norm */ true,
-                               /* post_layer_norm */ true,
                                /* has_attn_mask */ true,
                                /* do_dropout */ true,
                                /* add_residual */ true);
@@ -891,9 +1067,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
                               Graph* g) {
     VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion";
 
-    GET_IR_NODE_FROM_SUBGRAPH(post_layer_norm_grad_op_node,
-                              post_layer_norm_grad_op,
-                              fused_attention_grad_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node,
                               residual_ele_add_grad_op,
                               fused_attention_grad_pattern);
@@ -953,17 +1126,26 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
     // TODO(Yuang Liu): finish the handler
 
     GraphSafeRemoveNodes(
         g,
-        {post_layer_norm_grad_op_node, residual_ele_add_grad_op_node,
-         out_linear_dropout_grad_op_node, out_linear_ele_add_grad_op_node,
-         out_linear_matmul_grad_op_node, qkv_reshape_grad_op_node,
-         qkv_transpose_grad_op_node, qkv_matmul_grad_op_node,
-         attn_dropout_grad_op_node, qk_softmax_grad_op_node,
-         add_mask_ele_add_grad_op_node, qk_scale_grad_op_node,
-         qk_matmul_grad_op_node, fuse_qkv_split_grad_op_node,
-         fuse_qkv_transpose_grad_op_node, fuse_qkv_reshape_grad_op_node,
-         fuse_qkv_ele_add_grad_op_node, fuse_qkv_matmul_grad_op_node,
-         pre_layer_norm_grad_op_node, grad_accumulation_sum_op_node});
+        {residual_ele_add_grad_op_node,
+         out_linear_dropout_grad_op_node,
+         out_linear_ele_add_grad_op_node,
+         out_linear_matmul_grad_op_node,
+         qkv_reshape_grad_op_node,
+         qkv_transpose_grad_op_node,
+         qkv_matmul_grad_op_node,
+         attn_dropout_grad_op_node,
+         qk_softmax_grad_op_node,
+         add_mask_ele_add_grad_op_node,
+         qk_scale_grad_op_node,
+         qk_matmul_grad_op_node,
+         fuse_qkv_split_grad_op_node,
+         fuse_qkv_transpose_grad_op_node,
+         fuse_qkv_reshape_grad_op_node,
+         fuse_qkv_ele_add_grad_op_node,
+         fuse_qkv_matmul_grad_op_node,
+         pre_layer_norm_grad_op_node,
+         grad_accumulation_sum_op_node});
 
     found_fused_attention++;
   };
paddle/fluid/framework/ir/fused_attention_pass.h
@@ -28,7 +28,7 @@ namespace patterns {
 
 // Declare patterns for multi head attention.
 // Can detect:
-// 1. Pre layer norm, post layer norm or sandwich layer norm.
+// 1. Pre layer norm or post layer norm.
 // 2. Add attn mask for qk product before the softmax or not.
 // 3. Do attn dropout or not.
 // 4. Add residual to the out linear result or not.
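Editorial aside: the comment above lists what the pattern is meant to detect. A rough Python sketch of the unfused, pre-layer-norm attention block it corresponds to is given below; the shapes, argument names, and use of paddle's functional API are assumptions made for illustration, and the inline comments map each step to the op names used by the pattern.

# Rough, editorial stand-in for the unfused attention subgraph the pattern
# matches (pre layer norm variant); names and shapes are illustrative only.
import paddle
import paddle.nn.functional as F


def unfused_attention(x, qkv_w, qkv_b, out_w, out_b, ln_w, ln_b,
                      num_heads, attn_mask=None, dropout_p=0.1):
    hidden = x.shape[-1]                # x: [batch, seq, hidden]
    head_dim = hidden // num_heads
    residual = x

    ln_out = F.layer_norm(x, [hidden], weight=ln_w, bias=ln_b)    # pre_layer_norm

    qkv = paddle.matmul(ln_out, qkv_w) + qkv_b                    # fuse_qkv_matmul + ele_add
    qkv = paddle.reshape(qkv, [0, 0, num_heads, 3 * head_dim])    # fuse_qkv_reshape (shape[2] == num_heads)
    qkv = paddle.transpose(qkv, [0, 2, 1, 3])                     # fuse_qkv_transpose
    q, k, v = paddle.split(qkv, 3, axis=-1)                       # fuse_qkv_split

    qk = paddle.matmul(q, k, transpose_y=True) / head_dim**0.5    # qk_matmul + qk_scale
    if attn_mask is not None:
        qk = qk + attn_mask                                       # add_mask_ele_add
    weights = F.dropout(F.softmax(qk), p=dropout_p)               # qk_softmax + attn_dropout

    ctx = paddle.matmul(weights, v)                               # qkv_matmul
    ctx = paddle.transpose(ctx, [0, 2, 1, 3])                     # qkv_transpose
    ctx = paddle.reshape(ctx, [0, 0, hidden])                     # qkv_reshape
    out = paddle.matmul(ctx, out_w) + out_b                       # out_linear matmul + ele_add
    out = F.dropout(out, p=dropout_p)                             # out_linear dropout
    return residual + out                                         # residual_ele_add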
@@ -37,11 +37,10 @@ struct FusedAttentionPattern : public PatternBase {
       : PatternBase(pattern, name_scope, "fused_attention_pattern") {}
 
   PDNode* operator()(PDNode* x,
                      bool pre_layer_norm,   // do pre ln or not
-                     bool post_layer_norm,  // do post ln or not
                      bool has_attn_mask,    // add attn mask to qk or not
                      bool do_dropout,       // dropout the softmax(qk) or not
                      bool add_residual);    // add residual to out linear or not
 
   // pre layer norm
   PATTERN_DECL_NODE(pre_layer_norm_op);
@@ -134,11 +133,10 @@ struct FusedAttentionGradPattern : public PatternBase {
       : PatternBase(pattern, name_scope, "fused_attention_pattern") {}
 
   PDNode* operator()(PDNode* x,
                      bool pre_layer_norm,   // pre ln
-                     bool post_layer_norm,  // post ln
                      bool has_attn_mask,    // add attn mask to qk or not
                      bool do_dropout,       // dropout the softmax(qk) or not
                      bool add_residual);    // add residual to out linear or not
 
   // post layer norm grad
   PATTERN_DECL_NODE(post_layer_norm_grad_op);
@@ -275,9 +273,9 @@ class FusedAttentionsPass : public FusePassBase {
   // If true, the function name will have an abbreviation part.
   // If false, the function name won't contain an abbreviation for it.
-  ir::Graph* PreMaskDropResPostFwd(Graph* graph) const;
-  ir::Graph* PreMaskDropResPostBwd(Graph* graph) const;
+  ir::Graph* PreMaskDropResFwd(Graph* graph) const;
+  ir::Graph* PreMaskDropResBwd(Graph* graph) const;
 };
 
 }  // namespace ir
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
@@ -31,7 +31,6 @@ class MultiHeadAttention(paddle.nn.Layer):
         num_heads,
         add_residual=True,
         pre_ln=True,
-        post_ln=False,
         attn_dropout=True,
     ):
         super(MultiHeadAttention, self).__init__()
@@ -42,7 +41,6 @@ class MultiHeadAttention(paddle.nn.Layer):
         self.add_residual = add_residual
         self.pre_ln = pre_ln
-        self.post_ln = post_ln
         self.attn_dropout = attn_dropout
         self.head_dim = embed_dim // num_heads
@@ -90,7 +88,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         if self.add_residual:
             out = residual + out
 
-        if self.post_ln:
+        if not self.pre_ln:
             # post layer norm
             out = self.norm2(out)
@@ -104,7 +102,6 @@ class TestFusedAttentionPass(unittest.TestCase):
     def setUp(self):
        self.add_residual = True
        self.pre_ln = True
-       self.post_ln = True
        self.attn_dropout = True
        self.add_mask = True
@@ -120,6 +117,7 @@ class TestFusedAttentionPass(unittest.TestCase):
         ).astype('float32')
 
         main_prog = paddle.static.Program()
+        main_prog.random_seed = 1234
         startup_prog = paddle.static.Program()
 
         with paddle.static.program_guard(main_prog, startup_prog):
@@ -142,7 +140,6 @@ class TestFusedAttentionPass(unittest.TestCase):
                 num_heads,
                 add_residual=self.add_residual,
                 pre_ln=self.pre_ln,
-                post_ln=self.post_ln,
                 attn_dropout=self.attn_dropout,
             )
@@ -157,13 +154,14 @@ class TestFusedAttentionPass(unittest.TestCase):
             pass_manager.apply([main_prog], [startup_prog])
 
             ops = main_prog.global_block().ops
-            assert ops[2].type == 'reduce_mean'
-            assert ops[4].type == 'reduce_mean_grad'
+            assert ops[2].type == 'fused_attention'
+            assert ops[3].type == 'reduce_mean'
+            assert ops[5].type == 'reduce_mean_grad'
             # two ops for linear, one op for reduce mean
             # one fill constant
             # one op for reduce mean grad, two ops for linear bwd
             # the eighth op should be the optimizer
-            assert ops[7].type == 'sgd'
+            assert ops[8].type == 'sgd'
 
 
 if __name__ == "__main__":
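Editorial aside: for context, the test above drives the pass roughly as sketched below, reconstructed around the pass_manager.apply([main_prog], [startup_prog]) call shown in the diff. The import path and the registered pass name "fused_attention" are assumptions inferred from the op this pass creates; the diff itself does not confirm them.

# Sketch of applying the pass from Python; PassManager/new_pass usage is
# inferred from the test above, and the pass name is an assumption.
import paddle
from paddle.distributed.passes import PassManager, new_pass

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()

# ... build the multi-head attention network under
# paddle.static.program_guard(main_prog, startup_prog) ...

pass_manager = PassManager([new_pass("fused_attention")])
pass_manager.apply([main_prog], [startup_prog])

# After the pass, the matched subgraph should appear as one fused_attention op.
print([op.type for op in main_prog.global_block().ops])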