Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7e8ef328
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7e8ef328
编写于
2月 03, 2023
作者:
Y
Yuang Liu
提交者:
GitHub
2月 03, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fused attention pass backward op replace. (#50186)
上级
f2ec69b4
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
428 addition
and
14 deletion
+428
-14
paddle/fluid/framework/ir/fused_attention_pass.cc
paddle/fluid/framework/ir/fused_attention_pass.cc
+381
-10
paddle/fluid/framework/ir/fused_attention_pass.h
paddle/fluid/framework/ir/fused_attention_pass.h
+36
-2
paddle/fluid/operators/fused/fused_attention_op.cc
paddle/fluid/operators/fused/fused_attention_op.cc
+1
-1
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
...paddle/fluid/tests/unittests/test_fused_attention_pass.py
+10
-1
未找到文件。
paddle/fluid/framework/ir/fused_attention_pass.cc
浏览文件 @
7e8ef328
...
@@ -766,12 +766,15 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
...
@@ -766,12 +766,15 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
void
FusedAttentionsPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
void
FusedAttentionsPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
FusePassBase
::
Init
(
name_scope_
,
graph
);
FusedAttentionPassCache
cache
;
graph
=
PreMaskDropResFwd
(
graph
);
graph
=
PreMaskDropResFwd
(
graph
,
&
cache
);
graph
=
PreMaskDropResBwd
(
graph
);
graph
=
PreMaskDropResBwd
(
graph
,
&
cache
);
cache
.
ResetCache
();
}
}
ir
::
Graph
*
FusedAttentionsPass
::
PreMaskDropResFwd
(
Graph
*
graph
)
const
{
ir
::
Graph
*
FusedAttentionsPass
::
PreMaskDropResFwd
(
Graph
*
graph
,
FusedAttentionPassCache
*
cache
)
const
{
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
x
=
gpd
.
mutable_pattern
()
auto
*
x
=
gpd
.
mutable_pattern
()
->
NewNode
(
patterns
::
PDNodeName
(
name_scope_
,
"x"
))
->
NewNode
(
patterns
::
PDNodeName
(
name_scope_
,
"x"
))
...
@@ -792,6 +795,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -792,6 +795,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
3
)
<<
"handle FusedMultiHeadAttention pass's fusion"
;
VLOG
(
3
)
<<
"handle FusedMultiHeadAttention pass's fusion"
;
int
block_id
=
g
->
GetBlockId
();
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
pre_layer_norm_op_node
,
pre_layer_norm_op
,
fused_attention_pattern
);
pre_layer_norm_op_node
,
pre_layer_norm_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
...
@@ -833,9 +838,15 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -833,9 +838,15 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
residual_ele_add_op_node
,
residual_ele_add_op
,
fused_attention_pattern
);
residual_ele_add_op_node
,
residual_ele_add_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_matmul_w_node
,
fuse_qkv_matmul_w
,
fused_attention_pattern
);
std
::
string
cache_anchor_name
=
fuse_qkv_matmul_w_node
->
Var
()
->
Name
();
OpDesc
fused_attention_op_desc
(
pre_layer_norm_op_node
->
Op
()
->
Block
());
OpDesc
fused_attention_op_desc
(
pre_layer_norm_op_node
->
Op
()
->
Block
());
fused_attention_op_desc
.
SetType
(
"fused_attention"
);
fused_attention_op_desc
.
SetType
(
"fused_attention"
);
fused_attention_op_desc
.
SetInput
(
"X"
,
{
subgraph
.
at
(
x
)
->
Name
()});
fused_attention_op_desc
.
SetInput
(
"X"
,
{
subgraph
.
at
(
x
)
->
Name
()});
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"X"
,
block_id
),
subgraph
.
at
(
x
));
fused_attention_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_attention_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
GET_IR_NODE_FROM_SUBGRAPH
(
pre_layer_norm_scale_node
,
GET_IR_NODE_FROM_SUBGRAPH
(
pre_layer_norm_scale_node
,
...
@@ -860,6 +871,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -860,6 +871,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
{
pre_layer_norm_mean_node
->
Name
()});
{
pre_layer_norm_mean_node
->
Name
()});
fused_attention_op_desc
.
SetOutput
(
"LnVariance"
,
fused_attention_op_desc
.
SetOutput
(
"LnVariance"
,
{
pre_layer_norm_variance_node
->
Name
()});
{
pre_layer_norm_variance_node
->
Name
()});
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnScale"
,
block_id
),
pre_layer_norm_scale_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnBias"
,
block_id
),
pre_layer_norm_bias_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnOut"
,
block_id
),
pre_layer_norm_out_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnMean"
,
block_id
),
pre_layer_norm_mean_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnVariance"
,
block_id
),
pre_layer_norm_variance_node
);
fused_attention_op_desc
.
SetAttr
(
fused_attention_op_desc
.
SetAttr
(
"epsilon"
,
"epsilon"
,
PADDLE_GET_CONST
(
float
,
PADDLE_GET_CONST
(
float
,
...
@@ -869,8 +895,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -869,8 +895,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
std
::
vector
<
int
>
shape
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
shape
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
fuse_qkv_reshape_op_node
->
Op
()
->
GetAttr
(
"shape"
));
std
::
vector
<
int
>
,
fuse_qkv_reshape_op_node
->
Op
()
->
GetAttr
(
"shape"
));
fused_attention_op_desc
.
SetAttr
(
"num_heads"
,
shape
[
2
]);
fused_attention_op_desc
.
SetAttr
(
"num_heads"
,
shape
[
2
]);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_matmul_w_node
,
fuse_qkv_matmul_w
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_matmul_out_node
,
fuse_qkv_matmul_out
,
fused_attention_pattern
);
fuse_qkv_matmul_out_node
,
fuse_qkv_matmul_out
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_ele_add_bias_node
,
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_ele_add_bias_node
,
...
@@ -891,6 +915,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -891,6 +915,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
{
fuse_qkv_ele_add_out_node
->
Name
()});
{
fuse_qkv_ele_add_out_node
->
Name
()});
fused_attention_op_desc
.
SetOutput
(
"TransposeOut2"
,
fused_attention_op_desc
.
SetOutput
(
"TransposeOut2"
,
{
fuse_qkv_transpose_out_node
->
Name
()});
{
fuse_qkv_transpose_out_node
->
Name
()});
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKVW"
,
block_id
),
fuse_qkv_matmul_w_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKVBias"
,
block_id
),
fuse_qkv_ele_add_bias_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKVOut"
,
block_id
),
fuse_qkv_matmul_out_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKVBiasOut"
,
block_id
),
fuse_qkv_ele_add_out_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"TransposeOut2"
,
block_id
),
fuse_qkv_transpose_out_node
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
qk_matmul_out_node
,
qk_matmul_out
,
fused_attention_pattern
);
qk_matmul_out_node
,
qk_matmul_out
,
fused_attention_pattern
);
...
@@ -911,12 +950,24 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -911,12 +950,24 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
qkv_reshape_out_node
,
qkv_reshape_out
,
fused_attention_pattern
);
qkv_reshape_out_node
,
qkv_reshape_out
,
fused_attention_pattern
);
fused_attention_op_desc
.
SetOutput
(
"QKOut"
,
{
qk_matmul_out_node
->
Name
()});
fused_attention_op_desc
.
SetOutput
(
"QKOut"
,
{
qk_matmul_out_node
->
Name
()});
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKOut"
,
block_id
),
qk_matmul_out_node
);
fused_attention_op_desc
.
SetInput
(
"SrcMask"
,
fused_attention_op_desc
.
SetInput
(
"SrcMask"
,
{
add_mask_ele_add_mask_node
->
Name
()});
{
add_mask_ele_add_mask_node
->
Name
()});
fused_attention_op_desc
.
SetOutput
(
"SrcMaskOut"
,
fused_attention_op_desc
.
SetOutput
(
"SrcMaskOut"
,
{
add_mask_ele_add_out_node
->
Name
()});
{
add_mask_ele_add_out_node
->
Name
()});
fused_attention_op_desc
.
SetOutput
(
"SoftmaxOut"
,
fused_attention_op_desc
.
SetOutput
(
"SoftmaxOut"
,
{
qk_softmax_out_node
->
Name
()});
{
qk_softmax_out_node
->
Name
()});
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"SrcMask"
,
block_id
),
add_mask_ele_add_mask_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"SrcMaskOut"
,
block_id
),
add_mask_ele_add_out_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"SoftmaxOut"
,
block_id
),
qk_softmax_out_node
);
fused_attention_op_desc
.
SetAttr
(
fused_attention_op_desc
.
SetAttr
(
"attn_dropout_rate"
,
"attn_dropout_rate"
,
PADDLE_GET_CONST
(
float
,
PADDLE_GET_CONST
(
float
,
...
@@ -943,6 +994,18 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -943,6 +994,18 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
fused_attention_op_desc
.
SetOutput
(
"QKTVOut"
,
{
qkv_matmul_out_node
->
Name
()});
fused_attention_op_desc
.
SetOutput
(
"QKTVOut"
,
{
qkv_matmul_out_node
->
Name
()});
fused_attention_op_desc
.
SetOutput
(
"FMHAOut"
,
fused_attention_op_desc
.
SetOutput
(
"FMHAOut"
,
{
qkv_reshape_out_node
->
Name
()});
{
qkv_reshape_out_node
->
Name
()});
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"AttnDropoutMaskOut"
,
block_id
),
attn_dropout_mask_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"AttnDropoutOut"
,
block_id
),
attn_dropout_out_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKTVOut"
,
block_id
),
qkv_matmul_out_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"FMHAOut"
,
block_id
),
qkv_reshape_out_node
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_matmul_w_node
,
out_linear_matmul_w
,
fused_attention_pattern
);
out_linear_matmul_w_node
,
out_linear_matmul_w
,
fused_attention_pattern
);
...
@@ -952,15 +1015,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -952,15 +1015,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_ele_add_bias_node
,
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_ele_add_bias_node
,
out_linear_ele_add_bias
,
out_linear_ele_add_bias
,
fused_attention_pattern
);
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_ele_add_out_node
,
out_linear_ele_add_out
,
fused_attention_pattern
);
fused_attention_op_desc
.
SetInput
(
"OutLinearW"
,
fused_attention_op_desc
.
SetInput
(
"OutLinearW"
,
{
out_linear_matmul_w_node
->
Name
()});
{
out_linear_matmul_w_node
->
Name
()});
fused_attention_op_desc
.
SetInput
(
"OutLinearBias"
,
fused_attention_op_desc
.
SetInput
(
"OutLinearBias"
,
{
out_linear_ele_add_bias_node
->
Name
()});
{
out_linear_ele_add_bias_node
->
Name
()});
fused_attention_op_desc
.
SetOutput
(
"OutLinearOut"
,
fused_attention_op_desc
.
SetOutput
(
"OutLinearOut"
,
{
out_linear_matmul_out_node
->
Name
()});
{
out_linear_matmul_out_node
->
Name
()});
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"OutLinearW"
,
block_id
),
out_linear_matmul_w_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"OutLinearBias"
,
block_id
),
out_linear_ele_add_bias_node
);
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"OutLinearOut"
,
block_id
),
out_linear_matmul_out_node
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_dropout_mask_node
,
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_dropout_mask_node
,
out_linear_dropout_mask
,
out_linear_dropout_mask
,
fused_attention_pattern
);
fused_attention_pattern
);
...
@@ -983,6 +1052,9 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -983,6 +1052,9 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
"dropout_implementation"
)));
"dropout_implementation"
)));
fused_attention_op_desc
.
SetOutput
(
"DropoutMaskOut"
,
fused_attention_op_desc
.
SetOutput
(
"DropoutMaskOut"
,
{
out_linear_dropout_mask_node
->
Name
()});
{
out_linear_dropout_mask_node
->
Name
()});
cache
->
InsertIntoCache
(
GenerateCacheKey
(
cache_anchor_name
,
"DropoutMaskOut"
,
block_id
),
out_linear_dropout_mask_node
);
GET_IR_NODE_FROM_SUBGRAPH
(
residual_ele_add_out_node
,
GET_IR_NODE_FROM_SUBGRAPH
(
residual_ele_add_out_node
,
residual_ele_add_out
,
residual_ele_add_out
,
...
@@ -1037,6 +1109,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -1037,6 +1109,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
out_linear_ele_add_op_node
,
out_linear_ele_add_op_node
,
out_linear_dropout_op_node
,
out_linear_dropout_op_node
,
residual_ele_add_op_node
});
residual_ele_add_op_node
});
found_fused_attention
++
;
found_fused_attention
++
;
};
};
...
@@ -1046,7 +1119,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
...
@@ -1046,7 +1119,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
return
graph
;
return
graph
;
}
}
ir
::
Graph
*
FusedAttentionsPass
::
PreMaskDropResBwd
(
Graph
*
graph
)
const
{
ir
::
Graph
*
FusedAttentionsPass
::
PreMaskDropResBwd
(
Graph
*
graph
,
FusedAttentionPassCache
*
cache
)
const
{
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
x
=
gpd
.
mutable_pattern
()
auto
*
x
=
gpd
.
mutable_pattern
()
->
NewNode
(
patterns
::
PDNodeName
(
name_scope_
,
"x"
))
->
NewNode
(
patterns
::
PDNodeName
(
name_scope_
,
"x"
))
...
@@ -1067,6 +1141,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
...
@@ -1067,6 +1141,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
3
)
<<
"handle FusedMultiHeadAttention backward pass's fusion"
;
VLOG
(
3
)
<<
"handle FusedMultiHeadAttention backward pass's fusion"
;
int
block_id
=
g
->
GetBlockId
();
GET_IR_NODE_FROM_SUBGRAPH
(
residual_ele_add_grad_op_node
,
GET_IR_NODE_FROM_SUBGRAPH
(
residual_ele_add_grad_op_node
,
residual_ele_add_grad_op
,
residual_ele_add_grad_op
,
fused_attention_grad_pattern
);
fused_attention_grad_pattern
);
...
@@ -1124,7 +1200,302 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
...
@@ -1124,7 +1200,302 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
grad_accumulation_sum_op
,
grad_accumulation_sum_op
,
fused_attention_grad_pattern
);
fused_attention_grad_pattern
);
// TODO(Yuang Liu): finish the handler
OpDesc
fused_attention_grad_op_desc
(
residual_ele_add_grad_op_node
->
Op
()
->
Block
());
fused_attention_grad_op_desc
.
SetType
(
"fused_attention_grad"
);
fused_attention_grad_op_desc
.
SetInput
(
"Y@GRAD"
,
{
subgraph
.
at
(
x
)
->
Name
()});
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_matmul_grad_w_node
,
fuse_qkv_matmul_grad_w
,
fused_attention_grad_pattern
);
std
::
string
cache_anchor_name
=
fuse_qkv_matmul_grad_w_node
->
Var
()
->
Name
();
auto
*
x_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"X"
,
block_id
));
auto
*
attn_dropout_mask_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"AttnDropoutMaskOut"
,
block_id
));
auto
*
attn_dropout_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"AttnDropoutOut"
,
block_id
));
auto
*
dropout_mask_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"DropoutMaskOut"
,
block_id
));
auto
*
fmha_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"FMHAOut"
,
block_id
));
auto
*
ln_bias_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnBias"
,
block_id
));
auto
*
ln_mean_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnMean"
,
block_id
));
auto
*
ln_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnOut"
,
block_id
));
auto
*
ln_scale_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnScale"
,
block_id
));
auto
*
ln_variance_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"LnVariance"
,
block_id
));
auto
*
out_linear_bias_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"OutLinearBias"
,
block_id
));
auto
*
out_linear_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"OutLinearOut"
,
block_id
));
auto
*
out_linear_w_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"OutLinearW"
,
block_id
));
auto
*
qk_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKOut"
,
block_id
));
auto
*
qktv_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKTVOut"
,
block_id
));
auto
*
qkv_bias_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKVBias"
,
block_id
));
auto
*
qkv_bias_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKVBiasOut"
,
block_id
));
auto
*
qkv_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKVOut"
,
block_id
));
auto
*
qkv_w_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"QKVW"
,
block_id
));
auto
*
softmax_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"SoftmaxOut"
,
block_id
));
auto
*
src_mask_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"SrcMask"
,
block_id
));
auto
*
src_mask_out_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"SrcMaskOut"
,
block_id
));
auto
*
transpose_out_2_node
=
cache
->
GetNodeFromCache
(
GenerateCacheKey
(
cache_anchor_name
,
"TransposeOut2"
,
block_id
));
fused_attention_grad_op_desc
.
SetInput
(
"X"
,
{
x_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"AttnDropoutMaskOut"
,
{
attn_dropout_mask_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"AttnDropoutOut"
,
{
attn_dropout_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"DropoutMaskOut"
,
{
dropout_mask_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"FMHAOut"
,
{
fmha_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"LnBias"
,
{
ln_bias_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"LnMean"
,
{
ln_mean_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"LnOut"
,
{
ln_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"LnScale"
,
{
ln_scale_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"LnVariance"
,
{
ln_variance_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"OutLinearBias"
,
{
out_linear_bias_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"OutLinearOut"
,
{
out_linear_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"OutLinearW"
,
{
out_linear_w_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"QKOut"
,
{
qk_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"QKTVOut"
,
{
qktv_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"QKVBias"
,
{
qkv_bias_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"QKVBiasOut"
,
{
qkv_bias_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"QKVOut"
,
{
qkv_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"QKVW"
,
{
qkv_w_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"SoftmaxOut"
,
{
softmax_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"SrcMask"
,
{
src_mask_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"SrcMaskOut"
,
{
src_mask_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetInput
(
"TransposeOut2"
,
{
transpose_out_2_node
->
Name
()});
fused_attention_grad_op_desc
.
SetAttr
(
"add_residual"
,
true
);
fused_attention_grad_op_desc
.
SetAttr
(
"attn_dropout_rate"
,
PADDLE_GET_CONST
(
float
,
attn_dropout_grad_op_node
->
Op
()
->
GetAttr
(
"dropout_prob"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"is_test"
,
PADDLE_GET_CONST
(
bool
,
attn_dropout_grad_op_node
->
Op
()
->
GetAttr
(
"is_test"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"attn_dropout_fix_seed"
,
PADDLE_GET_CONST
(
bool
,
attn_dropout_grad_op_node
->
Op
()
->
GetAttr
(
"fix_seed"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"attn_dropout_seed"
,
PADDLE_GET_CONST
(
int
,
attn_dropout_grad_op_node
->
Op
()
->
GetAttr
(
"seed"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"attn_dropout_implementation"
,
PADDLE_GET_CONST
(
std
::
string
,
attn_dropout_grad_op_node
->
Op
()
->
GetAttr
(
"dropout_implementation"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"dropout_rate"
,
PADDLE_GET_CONST
(
float
,
out_linear_dropout_grad_op_node
->
Op
()
->
GetAttr
(
"dropout_prob"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"dropout_fix_seed"
,
PADDLE_GET_CONST
(
bool
,
out_linear_dropout_grad_op_node
->
Op
()
->
GetAttr
(
"fix_seed"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"dropout_seed"
,
PADDLE_GET_CONST
(
int
,
out_linear_dropout_grad_op_node
->
Op
()
->
GetAttr
(
"seed"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"dropout_implementation"
,
PADDLE_GET_CONST
(
std
::
string
,
out_linear_dropout_grad_op_node
->
Op
()
->
GetAttr
(
"dropout_implementation"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"epsilon"
,
PADDLE_GET_CONST
(
float
,
pre_layer_norm_grad_op_node
->
Op
()
->
GetAttr
(
"epsilon"
)));
std
::
vector
<
int
>
shape
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
fuse_qkv_reshape_grad_op_node
->
Op
()
->
GetAttr
(
"shape"
));
fused_attention_grad_op_desc
.
SetAttr
(
"num_heads"
,
shape
[
2
]);
fused_attention_grad_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_attention_grad_op_desc
.
SetAttr
(
"transpose_qkv_wb"
,
true
);
// forward op will use default value
// but backward op has to set these redundant attrs
fused_attention_grad_op_desc
.
SetAttr
(
"ln_epsilon"
,
PADDLE_GET_CONST
(
float
,
pre_layer_norm_grad_op_node
->
Op
()
->
GetAttr
(
"epsilon"
)));
fused_attention_grad_op_desc
.
SetAttr
(
"ring_id"
,
-
1
);
GET_IR_NODE_FROM_SUBGRAPH
(
qkv_matmul_grad_x_grad_node
,
qkv_matmul_grad_x_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_matmul_grad_x_grad_node
,
out_linear_matmul_grad_x_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
pre_layer_norm_grad_bias_grad_node
,
pre_layer_norm_grad_bias_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_matmul_grad_x_grad_node
,
fuse_qkv_matmul_grad_x_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
pre_layer_norm_grad_scale_grad_node
,
pre_layer_norm_grad_scale_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_ele_add_grad_bias_grad_node
,
out_linear_ele_add_grad_bias_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_ele_add_grad_x_grad_node
,
out_linear_ele_add_grad_x_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_matmul_grad_w_grad_node
,
out_linear_matmul_grad_w_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qk_scale_grad_out_node
,
qk_scale_grad_out
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qkv_transpose_grad_out_node
,
qkv_transpose_grad_out
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_ele_add_grad_bias_grad_node
,
fuse_qkv_ele_add_grad_bias_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_reshape_grad_out_node
,
fuse_qkv_reshape_grad_out
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_ele_add_grad_x_grad_node
,
fuse_qkv_ele_add_grad_x_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_matmul_grad_w_grad_node
,
fuse_qkv_matmul_grad_w_grad
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
attn_dropout_grad_out_node
,
attn_dropout_grad_out
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qk_softmax_grad_out_node
,
qk_softmax_grad_out
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_split_grad_out_node
,
fuse_qkv_split_grad_out
,
fused_attention_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
grad_accumulation_out_node
,
grad_accumulation_out
,
fused_attention_grad_pattern
);
fused_attention_grad_op_desc
.
SetOutput
(
"AttnDropoutOut@GRAD"
,
{
qkv_matmul_grad_x_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"FMHAOut@GRAD"
,
{
out_linear_matmul_grad_x_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"LnBias@GRAD"
,
{
pre_layer_norm_grad_bias_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"LnOut@GRAD"
,
{
fuse_qkv_matmul_grad_x_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"LnScale@GRAD"
,
{
pre_layer_norm_grad_scale_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"OutLinearBias@GRAD"
,
{
out_linear_ele_add_grad_bias_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"OutLinearOut@GRAD"
,
{
out_linear_ele_add_grad_x_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"OutLinearW@GRAD"
,
{
out_linear_matmul_grad_w_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"QKOut@GRAD"
,
{
qk_scale_grad_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"QKTVOut@GRAD"
,
{
qkv_transpose_grad_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"QKVBias@GRAD"
,
{
fuse_qkv_ele_add_grad_bias_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"QKVBiasOut@GRAD"
,
{
fuse_qkv_reshape_grad_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"QKVOut@GRAD"
,
{
fuse_qkv_ele_add_grad_x_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"QKVW@GRAD"
,
{
fuse_qkv_matmul_grad_w_grad_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"SoftmaxOut@GRAD"
,
{
attn_dropout_grad_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"SrcMaskOut@GRAD"
,
{
qk_softmax_grad_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"TransposeOut2@GRAD"
,
{
fuse_qkv_split_grad_out_node
->
Name
()});
fused_attention_grad_op_desc
.
SetOutput
(
"X@GRAD"
,
{
grad_accumulation_out_node
->
Name
()});
auto
fused_attention_grad_node
=
g
->
CreateOpNode
(
&
fused_attention_grad_op_desc
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
qkv_matmul_grad_x_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
out_linear_matmul_grad_x_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
pre_layer_norm_grad_bias_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
fuse_qkv_matmul_grad_x_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
pre_layer_norm_grad_scale_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
out_linear_ele_add_grad_bias_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
out_linear_ele_add_grad_x_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
out_linear_matmul_grad_w_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
qk_scale_grad_out_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
qkv_transpose_grad_out_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
fuse_qkv_ele_add_grad_bias_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
fuse_qkv_reshape_grad_out_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
fuse_qkv_ele_add_grad_x_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
fuse_qkv_matmul_grad_w_grad_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
attn_dropout_grad_out_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
qk_softmax_grad_out_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
fuse_qkv_split_grad_out_node
);
IR_NODE_LINK_TO
(
fused_attention_grad_node
,
grad_accumulation_out_node
);
IR_NODE_LINK_TO
(
subgraph
.
at
(
x
),
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
x_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
attn_dropout_mask_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
attn_dropout_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
dropout_mask_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
fmha_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
ln_bias_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
ln_mean_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
ln_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
ln_scale_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
ln_variance_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
out_linear_bias_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
out_linear_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
out_linear_w_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
qk_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
qktv_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
qkv_bias_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
qkv_bias_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
qkv_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
qkv_w_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
softmax_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
src_mask_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
src_mask_out_node
,
fused_attention_grad_node
);
IR_NODE_LINK_TO
(
transpose_out_2_node
,
fused_attention_grad_node
);
GraphSafeRemoveNodes
(
g
,
GraphSafeRemoveNodes
(
g
,
{
residual_ele_add_grad_op_node
,
{
residual_ele_add_grad_op_node
,
...
...
paddle/fluid/framework/ir/fused_attention_pass.h
浏览文件 @
7e8ef328
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
...
@@ -252,6 +253,31 @@ struct FusedAttentionGradPattern : public PatternBase {
...
@@ -252,6 +253,31 @@ struct FusedAttentionGradPattern : public PatternBase {
}
// namespace patterns
}
// namespace patterns
class
FusedAttentionPassCache
{
public:
ir
::
Node
*
GetNodeFromCache
(
const
std
::
string
name
)
{
if
(
var_name_to_ir_node_cache_
.
count
(
name
))
{
return
var_name_to_ir_node_cache_
.
find
(
name
)
->
second
;
}
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The key (%d) of FusedAttentionCache does not exist."
,
name
));
}
void
InsertIntoCache
(
const
std
::
string
name
,
ir
::
Node
*
node
)
{
if
(
!
var_name_to_ir_node_cache_
.
count
(
name
))
{
var_name_to_ir_node_cache_
.
insert
({
name
,
node
});
}
else
{
PADDLE_THROW
(
platform
::
errors
::
AlreadyExists
(
"The key (%d) of FusedAttentionCache already exist."
,
name
));
}
}
void
ResetCache
()
{
var_name_to_ir_node_cache_
.
clear
();
}
private:
std
::
unordered_map
<
std
::
string
,
ir
::
Node
*>
var_name_to_ir_node_cache_
;
};
class
FusedAttentionsPass
:
public
FusePassBase
{
class
FusedAttentionsPass
:
public
FusePassBase
{
public:
public:
virtual
~
FusedAttentionsPass
()
{}
virtual
~
FusedAttentionsPass
()
{}
...
@@ -273,9 +299,17 @@ class FusedAttentionsPass : public FusePassBase {
...
@@ -273,9 +299,17 @@ class FusedAttentionsPass : public FusePassBase {
// If true, the function name will have an abbreviation part.
// If true, the function name will have an abbreviation part.
// If false, the function name won't contain an abbreviation for it.
// If false, the function name won't contain an abbreviation for it.
ir
::
Graph
*
PreMaskDropResFwd
(
Graph
*
graph
)
const
;
ir
::
Graph
*
PreMaskDropResFwd
(
Graph
*
graph
,
FusedAttentionPassCache
*
cache
)
const
;
ir
::
Graph
*
PreMaskDropResBwd
(
Graph
*
graph
,
FusedAttentionPassCache
*
cache
)
const
;
ir
::
Graph
*
PreMaskDropResBwd
(
Graph
*
graph
)
const
;
const
std
::
string
GenerateCacheKey
(
const
std
::
string
anchor
,
const
std
::
string
var_name
,
int
block_id
)
const
{
return
anchor
+
"_"
+
std
::
to_string
(
block_id
)
+
"_"
+
var_name
;
}
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/operators/fused/fused_attention_op.cc
浏览文件 @
7e8ef328
...
@@ -375,7 +375,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -375,7 +375,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"BiasDropoutResidualOut"
,
AddOutput
(
"BiasDropoutResidualOut"
,
"Result of residual + dropout(src + bias)."
)
"Result of residual + dropout(src + bias)."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"CacheKVOut"
,
"The udpated cache KV."
);
AddOutput
(
"CacheKVOut"
,
"The udpated cache KV."
)
.
AsDispensable
()
;
AddOutput
(
"Y"
,
"Result after attention."
);
AddOutput
(
"Y"
,
"Result after attention."
);
AddAttr
<
int
>
(
"num_heads"
,
"The number head for multi_head_attention."
)
AddAttr
<
int
>
(
"num_heads"
,
"The number head for multi_head_attention."
)
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
浏览文件 @
7e8ef328
...
@@ -157,11 +157,20 @@ class TestFusedAttentionPass(unittest.TestCase):
...
@@ -157,11 +157,20 @@ class TestFusedAttentionPass(unittest.TestCase):
assert
ops
[
2
].
type
==
'fused_attention'
assert
ops
[
2
].
type
==
'fused_attention'
assert
ops
[
3
].
type
==
'reduce_mean'
assert
ops
[
3
].
type
==
'reduce_mean'
assert
ops
[
5
].
type
==
'reduce_mean_grad'
assert
ops
[
5
].
type
==
'reduce_mean_grad'
assert
ops
[
6
].
type
==
'fused_attention_grad'
# two ops for linear, one op for reduce mean
# two ops for linear, one op for reduce mean
# one fill constant
# one fill constant
# one op for reduce mean grad, two ops for linear bwd
# one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer
# the eighth op should be the optimizer
assert
ops
[
8
].
type
==
'sgd'
assert
ops
[
9
].
type
==
'sgd'
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
startup_prog
)
rst
=
exe
.
run
(
main_prog
,
feed
=
{
'x'
:
x_data
,
'attn_mask'
:
mask_data
},
fetch_list
=
[
loss
],
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录