Commit 2b848aef (unverified)
Authored Feb 01, 2023 by Yuang Liu; committed via GitHub on Feb 01, 2023

Fused attention pass fwd, create the fused_attention op. (#50125)

Parent: e6d29e00
Showing 3 changed files with 231 additions and 53 deletions (+231 −53):

  paddle/fluid/framework/ir/fused_attention_pass.cc                 +214  −32
  paddle/fluid/framework/ir/fused_attention_pass.h                   +11  −13
  python/paddle/fluid/tests/unittests/test_fused_attention_pass.py    +6   −8
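For context, this pass is driven from Python through the static-graph pass manager, which is how the updated unit test below exercises it. A minimal usage sketch follows; the registration name "fused_attention" and the PassManager/new_pass imports are assumptions based on how Paddle's distributed-pass tests are usually wired, only pass_manager.apply(...) appears verbatim in this diff.

    # Hedged sketch: apply the fused attention pass to a static-graph program.
    import paddle
    from paddle.distributed.passes import PassManager, new_pass

    paddle.enable_static()
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()

    # ... build a pre-LN multi-head attention block inside main_prog ...

    pass_manager = PassManager([new_pass("fused_attention")])  # pass name is an assumption
    pass_manager.apply([main_prog], [startup_prog])

    # After the pass runs, the matched subgraph should show up as one fused_attention op.
    print([op.type for op in main_prog.global_block().ops])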
paddle/fluid/framework/ir/fused_attention_pass.cc
@@ -22,7 +22,6 @@ namespace patterns {
 PDNode* FusedAttentionPattern::operator()(PDNode* x,
                                           bool pre_layer_norm,
-                                          bool post_layer_norm,
                                           bool has_attn_mask,
                                           bool do_dropout,
                                           bool add_residual) {
@@ -259,7 +258,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
   out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node})
       .LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node});

-  if (!add_residual && !post_layer_norm) {
+  if (!add_residual && pre_layer_norm) {
     return out_linear_dropout_out_node;
   }
@@ -276,7 +275,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
   residual_ele_add_node->LinksFrom({x, out_linear_dropout_out_node})
       .LinksTo({residual_ele_add_out_node});

-  if (!post_layer_norm) {
+  if (pre_layer_norm) {
     return residual_ele_add_out_node;
   }
 }
@@ -323,13 +322,12 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
 PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
                                               bool pre_layer_norm,
-                                              bool post_layer_norm,
                                               bool has_attn_mask,
                                               bool do_dropout,
                                               bool add_residual) {
   // post layer norm
   PDNode* post_layer_norm_grad_out_node{nullptr};
-  if (post_layer_norm) {
+  if (!pre_layer_norm) {
     auto* post_layer_norm_grad_node =
         pattern->NewNode(post_layer_norm_grad_op_repr())
             ->assert_is_op("layer_norm_grad");
@@ -375,7 +373,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   PDNode* residual_ele_add_grad_x_grad_node{nullptr};
   if (add_residual) {
     PDNode* ele_add_grad_input = x;
-    if (post_layer_norm) {
+    if (!pre_layer_norm) {
      ele_add_grad_input = post_layer_norm_grad_out_node;
     }
     auto* residual_ele_add_grad_node =
@@ -404,7 +402,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   // get the real input x for dropout grad
   PDNode* out_linear_grad_input_node = x;
-  if (post_layer_norm && !add_residual) {
+  if (!pre_layer_norm && !add_residual) {
     out_linear_grad_input_node = post_layer_norm_grad_out_node;
   } else if (add_residual) {
     out_linear_grad_input_node = residual_ele_add_grad_out_node;
@@ -769,11 +767,11 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
 void FusedAttentionsPass::ApplyImpl(Graph* graph) const {
   FusePassBase::Init(name_scope_, graph);

-  graph = PreMaskDropResPostFwd(graph);
-  graph = PreMaskDropResPostBwd(graph);
+  graph = PreMaskDropResFwd(graph);
+  graph = PreMaskDropResBwd(graph);
 }

-ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
+ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
                 ->NewNode(patterns::PDNodeName(name_scope_, "x"))
@@ -784,7 +782,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
   fused_attention_pattern(x,
                           /* pre_layer_norm */ true,
-                          /* post_layer_norm */ true,
                           /* has_attn_mask */ true,
                           /* do_dropout */ true,
                           /* add_residual */ true);
@@ -835,10 +832,191 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
                               fused_attention_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        post_layer_norm_op_node, post_layer_norm_op, fused_attention_pattern);

-    // TODO(Yuang Liu): finish the handler
+    OpDesc fused_attention_op_desc(pre_layer_norm_op_node->Op()->Block());
+    fused_attention_op_desc.SetType("fused_attention");
+    fused_attention_op_desc.SetInput("X", {subgraph.at(x)->Name()});
+    fused_attention_op_desc.SetAttr("pre_layer_norm", true);
+    GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_scale_node,
+                              pre_layer_norm_scale,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_bias_node, pre_layer_norm_bias, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_out_node, pre_layer_norm_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_mean_node, pre_layer_norm_mean, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_variance_node,
+                              pre_layer_norm_variance,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetInput("LnScale",
+                                     {pre_layer_norm_scale_node->Name()});
+    fused_attention_op_desc.SetInput("LnBias",
+                                     {pre_layer_norm_bias_node->Name()});
+    fused_attention_op_desc.SetOutput("LnOut",
+                                      {pre_layer_norm_out_node->Name()});
+    fused_attention_op_desc.SetOutput("LnMean",
+                                      {pre_layer_norm_mean_node->Name()});
+    fused_attention_op_desc.SetOutput("LnVariance",
+                                      {pre_layer_norm_variance_node->Name()});
+    fused_attention_op_desc.SetAttr(
+        "epsilon",
+        PADDLE_GET_CONST(float,
+                         pre_layer_norm_op_node->Op()->GetAttr("epsilon")));
+
+    fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
+    std::vector<int> shape = PADDLE_GET_CONST(
+        std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
+    fused_attention_op_desc.SetAttr("num_heads", shape[2]);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
+                              fuse_qkv_ele_add_bias,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_out_node,
+                              fuse_qkv_ele_add_out,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_transpose_out_node,
+                              fuse_qkv_transpose_out,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetInput("QKVW", {fuse_qkv_matmul_w_node->Name()});
+    fused_attention_op_desc.SetInput("QKVBias",
+                                     {fuse_qkv_ele_add_bias_node->Name()});
+    fused_attention_op_desc.SetOutput("QKVOut",
+                                      {fuse_qkv_matmul_out_node->Name()});
+    fused_attention_op_desc.SetOutput("QKVBiasOut",
+                                      {fuse_qkv_ele_add_out_node->Name()});
+    fused_attention_op_desc.SetOutput("TransposeOut2",
+                                      {fuse_qkv_transpose_out_node->Name()});
+
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qk_matmul_out_node, qk_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_mask_node,
+                              add_mask_ele_add_mask,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_out_node,
+                              add_mask_ele_add_out,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qk_softmax_out_node, qk_softmax_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        attn_dropout_out_node, attn_dropout_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        attn_dropout_mask_node, attn_dropout_mask, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qkv_matmul_out_node, qkv_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qkv_reshape_out_node, qkv_reshape_out, fused_attention_pattern);
+    fused_attention_op_desc.SetOutput("QKOut", {qk_matmul_out_node->Name()});
+    fused_attention_op_desc.SetInput("SrcMask",
+                                     {add_mask_ele_add_mask_node->Name()});
+    fused_attention_op_desc.SetOutput("SrcMaskOut",
+                                      {add_mask_ele_add_out_node->Name()});
+    fused_attention_op_desc.SetOutput("SoftmaxOut",
+                                      {qk_softmax_out_node->Name()});
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_rate",
+        PADDLE_GET_CONST(float,
+                         attn_dropout_op_node->Op()->GetAttr("dropout_prob")));
+    fused_attention_op_desc.SetAttr(
+        "is_test",
+        PADDLE_GET_CONST(bool, attn_dropout_op_node->Op()->GetAttr("is_test")));
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_fix_seed",
+        PADDLE_GET_CONST(bool, attn_dropout_op_node->Op()->GetAttr("fix_seed")));
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_seed",
+        PADDLE_GET_CONST(int, attn_dropout_op_node->Op()->GetAttr("seed")));
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_implementation",
+        PADDLE_GET_CONST(
+            std::string,
+            attn_dropout_op_node->Op()->GetAttr("dropout_implementation")));
+    fused_attention_op_desc.SetOutput("AttnDropoutMaskOut",
+                                      {attn_dropout_mask_node->Name()});
+    fused_attention_op_desc.SetOutput("AttnDropoutOut",
+                                      {attn_dropout_out_node->Name()});
+    fused_attention_op_desc.SetOutput("QKTVOut", {qkv_matmul_out_node->Name()});
+    fused_attention_op_desc.SetOutput("FMHAOut",
+                                      {qkv_reshape_out_node->Name()});
+
+    GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_w_node,
+                              out_linear_matmul_w,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_out_node,
+                              out_linear_matmul_out,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_bias_node,
+                              out_linear_ele_add_bias,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_out_node,
+                              out_linear_ele_add_out,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetInput("OutLinearW",
+                                     {out_linear_matmul_w_node->Name()});
+    fused_attention_op_desc.SetInput("OutLinearBias",
+                                     {out_linear_ele_add_bias_node->Name()});
+    fused_attention_op_desc.SetOutput("OutLinearOut",
+                                      {out_linear_matmul_out_node->Name()});
+
+    GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_mask_node,
+                              out_linear_dropout_mask,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetAttr(
+        "dropout_rate",
+        PADDLE_GET_CONST(
+            float, out_linear_dropout_op_node->Op()->GetAttr("dropout_prob")));
+    fused_attention_op_desc.SetAttr(
+        "dropout_fix_seed",
+        PADDLE_GET_CONST(bool,
+                         out_linear_dropout_op_node->Op()->GetAttr("fix_seed")));
+    fused_attention_op_desc.SetAttr(
+        "dropout_seed",
+        PADDLE_GET_CONST(int, out_linear_dropout_op_node->Op()->GetAttr("seed")));
+    fused_attention_op_desc.SetAttr(
+        "dropout_implementation",
+        PADDLE_GET_CONST(
+            std::string,
+            out_linear_dropout_op_node->Op()->GetAttr("dropout_implementation")));
+    fused_attention_op_desc.SetOutput("DropoutMaskOut",
+                                      {out_linear_dropout_mask_node->Name()});
+
+    GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_out_node,
+                              residual_ele_add_out,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetAttr("add_residual", true);
+    fused_attention_op_desc.SetOutput("Y", {residual_ele_add_out_node->Name()});
+
+    auto fused_attention_node = g->CreateOpNode(&fused_attention_op_desc);
+
+    IR_NODE_LINK_TO(subgraph.at(x), fused_attention_node);
+    IR_NODE_LINK_TO(pre_layer_norm_scale_node, fused_attention_node);
+    IR_NODE_LINK_TO(pre_layer_norm_bias_node, fused_attention_node);
+    IR_NODE_LINK_TO(fuse_qkv_matmul_w_node, fused_attention_node);
+    IR_NODE_LINK_TO(fuse_qkv_ele_add_bias_node, fused_attention_node);
+    IR_NODE_LINK_TO(add_mask_ele_add_mask_node, fused_attention_node);
+    IR_NODE_LINK_TO(out_linear_matmul_w_node, fused_attention_node);
+    IR_NODE_LINK_TO(out_linear_ele_add_bias_node, fused_attention_node);
+
+    IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_mean_node);
+    IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_variance_node);
+    IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_ele_add_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_transpose_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qk_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, add_mask_ele_add_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qk_softmax_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, attn_dropout_mask_node);
+    IR_NODE_LINK_TO(fused_attention_node, attn_dropout_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qkv_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qkv_reshape_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, out_linear_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, out_linear_dropout_mask_node);
+    IR_NODE_LINK_TO(fused_attention_node, residual_ele_add_out_node);
+
     GraphSafeRemoveNodes(g,
                          {pre_layer_norm_op_node,
@@ -858,8 +1036,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
                           out_linear_matmul_op_node,
                           out_linear_ele_add_op_node,
                           out_linear_dropout_op_node,
-                          residual_ele_add_op_node,
-                          post_layer_norm_op_node});
+                          residual_ele_add_op_node});

     found_fused_attention++;
   };
@@ -869,18 +1046,17 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
   return graph;
 }

-ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
+ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
                 ->NewNode(patterns::PDNodeName(name_scope_, "x"))
                 ->AsInput()
-                ->assert_is_op_input("layer_norm_grad", "Y@GRAD");
+                ->assert_is_op_input("elementwise_add_grad", "Out@GRAD");

   patterns::FusedAttentionGradPattern fused_attention_grad_pattern(
       gpd.mutable_pattern(), "fused_attention_grad_pattern");
   fused_attention_grad_pattern(x,
                                /* pre_layer_norm */ true,
-                               /* post_layer_norm */ true,
                                /* has_attn_mask */ true,
                                /* do_dropout */ true,
                                /* add_residual */ true);
@@ -891,9 +1067,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
                                      Graph* g) {
     VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion";

-    GET_IR_NODE_FROM_SUBGRAPH(post_layer_norm_grad_op_node,
-                              post_layer_norm_grad_op,
-                              fused_attention_grad_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node,
                               residual_ele_add_grad_op,
                               fused_attention_grad_pattern);
@@ -953,17 +1126,26 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
     // TODO(Yuang Liu): finish the handler

     GraphSafeRemoveNodes(g,
-                         {post_layer_norm_grad_op_node,
-                          residual_ele_add_grad_op_node,
+                         {residual_ele_add_grad_op_node,
                           out_linear_dropout_grad_op_node,
                           out_linear_ele_add_grad_op_node,
                           out_linear_matmul_grad_op_node,
                           qkv_reshape_grad_op_node,
                           qkv_transpose_grad_op_node,
                           qkv_matmul_grad_op_node,
                           attn_dropout_grad_op_node,
                           qk_softmax_grad_op_node,
                           add_mask_ele_add_grad_op_node,
                           qk_scale_grad_op_node,
                           qk_matmul_grad_op_node,
                           fuse_qkv_split_grad_op_node,
                           fuse_qkv_transpose_grad_op_node,
                           fuse_qkv_reshape_grad_op_node,
                           fuse_qkv_ele_add_grad_op_node,
                           fuse_qkv_matmul_grad_op_node,
                           pre_layer_norm_grad_op_node,
                           grad_accumulation_sum_op_node});

     found_fused_attention++;
   };
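Taken together, the forward handler above collapses the matched pre-LN attention subgraph into a single fused_attention op whose inputs (X, LnScale, LnBias, QKVW, QKVBias, SrcMask, OutLinearW, OutLinearBias), outputs (LnOut through Y) and attributes (pre_layer_norm, epsilon, num_heads, the dropout rates/seeds, add_residual) come from the matched nodes. A hedged sketch of inspecting the rewritten program from Python, assuming main_prog has already been run through the pass as in the earlier sketch and the unit test below:

    # Hedged sketch: inspect the fused_attention op produced by the pass.
    # The input/output/attribute names follow the SetInput/SetOutput/SetAttr
    # calls in the handler above.
    fused_ops = [op for op in main_prog.global_block().ops
                 if op.type == "fused_attention"]
    assert len(fused_ops) == 1

    op = fused_ops[0]
    print(op.input("X"), op.input("QKVW"), op.input("SrcMask"))
    print(op.output("Y"))
    print(op.attr("pre_layer_norm"), op.attr("num_heads"), op.attr("add_residual"))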
paddle/fluid/framework/ir/fused_attention_pass.h
@@ -28,7 +28,7 @@ namespace patterns {

 // Declare patterns for multi head attention.
 // Can detect:
-// 1. Pre layer norm, post layer norm or sandwich layer norm.
+// 1. Pre layer norm or post layer norm.
 // 2. Add attn mask for qk product before the softmax or not.
 // 3. Do attn dropout or not.
 // 4. Add residual to the out linear result or not.
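To make the four knobs above concrete, here is a rough Python sketch of the pre-LN attention block this pattern targets, modeled on the MultiHeadAttention layer in the unit test further below; the names (qkv_proj, out_proj, norm) and shapes are illustrative, not part of this commit.

    import paddle
    import paddle.nn.functional as F

    def pre_ln_attention(x, norm, qkv_proj, out_proj, num_heads,
                         attn_mask=None, dropout_rate=0.1, add_residual=True):
        bsz, seq_len, embed_dim = x.shape
        head_dim = embed_dim // num_heads

        residual = x
        x = norm(x)                                   # 1. pre layer norm
        qkv = qkv_proj(x)                             # fused QKV matmul + bias add
        qkv = qkv.reshape([bsz, seq_len, 3, num_heads, head_dim])
        qkv = qkv.transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]              # split into q, k, v

        attn = paddle.matmul(q, k, transpose_y=True)  # qk matmul
        attn = attn * (head_dim ** -0.5)              # scale
        if attn_mask is not None:                     # 2. optional attn mask
            attn = attn + attn_mask
        attn = F.softmax(attn)
        attn = F.dropout(attn, dropout_rate)          # 3. optional attn dropout

        out = paddle.matmul(attn, v)                  # qkv matmul
        out = out.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, embed_dim])
        out = out_proj(out)                           # out linear
        out = F.dropout(out, dropout_rate)
        if add_residual:                              # 4. optional residual add
            out = residual + out
        return out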
@@ -37,11 +37,10 @@ struct FusedAttentionPattern : public PatternBase {
       : PatternBase(pattern, name_scope, "fused_attention_pattern") {}

   PDNode* operator()(PDNode* x,
-                     bool pre_layer_norm,   // do pre ln or not
-                     bool post_layer_norm,  // do post ln or not
-                     bool has_attn_mask,    // add attn mask to qk or not
-                     bool do_dropout,       // dropout the softmax(qk) or not
-                     bool add_residual);    // add residual to out linear or not
+                     bool pre_layer_norm,  // do pre ln or not
+                     bool has_attn_mask,   // add attn mask to qk or not
+                     bool do_dropout,      // dropout the softmax(qk) or not
+                     bool add_residual);   // add residual to out linear or not

   // pre layer norm
   PATTERN_DECL_NODE(pre_layer_norm_op);
@@ -134,11 +133,10 @@ struct FusedAttentionGradPattern : public PatternBase {
       : PatternBase(pattern, name_scope, "fused_attention_pattern") {}

   PDNode* operator()(PDNode* x,
-                     bool pre_layer_norm,   // pre ln
-                     bool post_layer_norm,  // post ln
-                     bool has_attn_mask,    // add attn mask to qk or not
-                     bool do_dropout,       // dropout the softmax(qk) or not
-                     bool add_residual);    // add residual to out linear or not
+                     bool pre_layer_norm,  // pre ln
+                     bool has_attn_mask,   // add attn mask to qk or not
+                     bool do_dropout,      // dropout the softmax(qk) or not
+                     bool add_residual);   // add residual to out linear or not

   // post layer norm grad
   PATTERN_DECL_NODE(post_layer_norm_grad_op);
@@ -275,9 +273,9 @@ class FusedAttentionsPass : public FusePassBase {
   // If true, the function name will have an abbreviation part.
   // If false, the function name won't contain an abbreviation for it.

-  ir::Graph* PreMaskDropResPostFwd(Graph* graph) const;
+  ir::Graph* PreMaskDropResFwd(Graph* graph) const;

-  ir::Graph* PreMaskDropResPostBwd(Graph* graph) const;
+  ir::Graph* PreMaskDropResBwd(Graph* graph) const;
 };

 }  // namespace ir
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
@@ -31,7 +31,6 @@ class MultiHeadAttention(paddle.nn.Layer):
         num_heads,
         add_residual=True,
         pre_ln=True,
-        post_ln=False,
         attn_dropout=True,
     ):
         super(MultiHeadAttention, self).__init__()
@@ -42,7 +41,6 @@ class MultiHeadAttention(paddle.nn.Layer):
         self.add_residual = add_residual
         self.pre_ln = pre_ln
-        self.post_ln = post_ln
         self.attn_dropout = attn_dropout

         self.head_dim = embed_dim // num_heads
@@ -90,7 +88,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         if self.add_residual:
             out = residual + out

-        if self.post_ln:
+        if not self.pre_ln:
             # post layer norm
             out = self.norm2(out)
@@ -104,7 +102,6 @@ class TestFusedAttentionPass(unittest.TestCase):
     def setUp(self):
         self.add_residual = True
         self.pre_ln = True
-        self.post_ln = True
         self.attn_dropout = True
         self.add_mask = True
@@ -120,6 +117,7 @@ class TestFusedAttentionPass(unittest.TestCase):
         ).astype('float32')

         main_prog = paddle.static.Program()
+        main_prog.random_seed = 1234
         startup_prog = paddle.static.Program()

         with paddle.static.program_guard(main_prog, startup_prog):
@@ -142,7 +140,6 @@ class TestFusedAttentionPass(unittest.TestCase):
                 num_heads,
                 add_residual=self.add_residual,
                 pre_ln=self.pre_ln,
-                post_ln=self.post_ln,
                 attn_dropout=self.attn_dropout,
             )
@@ -157,13 +154,14 @@ class TestFusedAttentionPass(unittest.TestCase):
         pass_manager.apply([main_prog], [startup_prog])

         ops = main_prog.global_block().ops
-        assert ops[2].type == 'reduce_mean'
-        assert ops[4].type == 'reduce_mean_grad'
+        assert ops[2].type == 'fused_attention'
+        assert ops[3].type == 'reduce_mean'
+        assert ops[5].type == 'reduce_mean_grad'
         # two ops for linear, one op for reduce mean
         # one fill constant
         # one op for reduce mean grad, two ops for linear bwd
         # the eighth op should be the optimizer
-        assert ops[7].type == 'sgd'
+        assert ops[8].type == 'sgd'


 if __name__ == "__main__":