Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fcec564c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fcec564c
编写于
2月 06, 2023
作者:
Y
Yuang Liu
提交者:
GitHub
2月 06, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fused attn pass single ut (#50227)
上级
8fb2dce9
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
71 addition
and
57 deletion
+71
-57
paddle/fluid/framework/ir/fused_attention_pass.cc
paddle/fluid/framework/ir/fused_attention_pass.cc
+29
-29
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
...paddle/fluid/tests/unittests/test_fused_attention_pass.py
+42
-28
未找到文件。
paddle/fluid/framework/ir/fused_attention_pass.cc
浏览文件 @
fcec564c
...
...
@@ -123,23 +123,23 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
fuse_qkv_split_out_v_node
});
// core attention pattern
auto
*
qk_scale_node
=
pattern
->
NewNode
(
qk_scale_op_repr
())
->
assert_is_op
(
"scale"
);
auto
*
qk_scale_out_node
=
pattern
->
NewNode
(
qk_scale_out_repr
())
->
assert_is_op_output
(
"scale"
);
fuse_qkv_split_out_q_node
->
assert_is_op_input
(
"scale"
,
"X"
);
qk_scale_node
->
LinksFrom
({
fuse_qkv_split_out_q_node
})
.
LinksTo
({
qk_scale_out_node
});
auto
*
qk_matmul_node
=
pattern
->
NewNode
(
qk_matmul_op_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
qk_matmul_out_node
=
pattern
->
NewNode
(
qk_matmul_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
fuse_qkv_split_out_q
_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
qk_scale_out
_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
fuse_qkv_split_out_k_node
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
qk_matmul_node
->
LinksFrom
({
fuse_qkv_split_out_q_node
,
fuse_qkv_split_out_k_node
})
qk_matmul_node
->
LinksFrom
({
qk_scale_out_node
,
fuse_qkv_split_out_k_node
})
.
LinksTo
({
qk_matmul_out_node
});
auto
*
qk_scale_node
=
pattern
->
NewNode
(
qk_scale_op_repr
())
->
assert_is_op
(
"scale"
);
auto
*
qk_scale_out_node
=
pattern
->
NewNode
(
qk_scale_out_repr
())
->
assert_is_op_output
(
"scale"
);
qk_matmul_out_node
->
assert_is_op_input
(
"scale"
,
"X"
);
qk_scale_node
->
LinksFrom
({
qk_matmul_out_node
}).
LinksTo
({
qk_scale_out_node
});
PDNode
*
add_mask_ele_add_out_node
{
nullptr
};
if
(
has_attn_mask
)
{
auto
*
add_mask_ele_add_node
=
pattern
->
NewNode
(
add_mask_ele_add_op_repr
())
...
...
@@ -149,9 +149,9 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
add_mask_ele_add_out_node
=
pattern
->
NewNode
(
add_mask_ele_add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
);
qk_
scale
_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
qk_
matmul
_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
add_mask_ele_add_node
->
LinksFrom
({
qk_
scale
_out_node
,
add_mask_ele_add_mask_node
})
->
LinksFrom
({
qk_
matmul
_out_node
,
add_mask_ele_add_mask_node
})
.
LinksTo
({
add_mask_ele_add_out_node
});
}
...
...
@@ -164,8 +164,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
qk_softmax_node
->
LinksFrom
({
add_mask_ele_add_out_node
})
.
LinksTo
({
qk_softmax_out_node
});
}
else
{
qk_
scale
_out_node
->
assert_is_op_input
(
"softmax"
,
"X"
);
qk_softmax_node
->
LinksFrom
({
qk_
scale
_out_node
})
qk_
matmul
_out_node
->
assert_is_op_input
(
"softmax"
,
"X"
);
qk_softmax_node
->
LinksFrom
({
qk_
matmul
_out_node
})
.
LinksTo
({
qk_softmax_out_node
});
}
...
...
@@ -575,16 +575,8 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
.
LinksTo
({
add_mask_ele_add_grad_x_grad_node
});
}
PDNode
*
qk_
scale
_grad_input_node
=
PDNode
*
qk_
matmul
_grad_input_node
=
has_attn_mask
?
add_mask_ele_add_grad_x_grad_node
:
qk_softmax_grad_out
;
auto
*
qk_scale_grad_node
=
pattern
->
NewNode
(
qk_scale_grad_op_repr
())
->
assert_is_op
(
"scale"
);
auto
*
qk_scale_grad_out_node
=
pattern
->
NewNode
(
qk_scale_grad_out_repr
())
->
assert_is_op_output
(
"scale"
);
qk_scale_grad_input_node
->
assert_is_op_input
(
"scale"
,
"X"
);
qk_scale_grad_node
->
LinksFrom
({
qk_scale_grad_input_node
})
.
LinksTo
({
qk_scale_grad_out_node
});
auto
*
qk_matmul_grad_node
=
pattern
->
NewNode
(
qk_matmul_grad_op_repr
())
->
assert_is_op
(
"matmul_v2_grad"
);
auto
*
qk_matmul_grad_x_node
=
pattern
->
NewNode
(
qk_matmul_grad_x_repr
())
...
...
@@ -597,24 +589,32 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
auto
*
qk_matmul_grad_w_grad_node
=
pattern
->
NewNode
(
qk_matmul_grad_w_grad_repr
())
->
assert_is_op_output
(
"matmul_v2_grad"
,
"Y@GRAD"
);
qk_
scale_grad_o
ut_node
->
assert_is_op_input
(
"matmul_v2_grad"
,
"Out@GRAD"
);
qk_
matmul_grad_inp
ut_node
->
assert_is_op_input
(
"matmul_v2_grad"
,
"Out@GRAD"
);
qk_matmul_grad_node
->
LinksFrom
({
qk_
scale_grad_o
ut_node
,
->
LinksFrom
({
qk_
matmul_grad_inp
ut_node
,
qk_matmul_grad_x_node
,
qk_matmul_grad_w_node
})
.
LinksTo
({
qk_matmul_grad_x_grad_node
,
qk_matmul_grad_w_grad_node
});
auto
*
qk_scale_grad_node
=
pattern
->
NewNode
(
qk_scale_grad_op_repr
())
->
assert_is_op
(
"scale"
);
auto
*
qk_scale_grad_out_node
=
pattern
->
NewNode
(
qk_scale_grad_out_repr
())
->
assert_is_op_output
(
"scale"
);
qk_matmul_grad_x_grad_node
->
assert_is_op_input
(
"scale"
,
"X"
);
qk_scale_grad_node
->
LinksFrom
({
qk_matmul_grad_x_grad_node
})
.
LinksTo
({
qk_scale_grad_out_node
});
// fuse qkv projection
auto
*
fuse_qkv_split_grad_node
=
pattern
->
NewNode
(
fuse_qkv_split_grad_op_repr
())
->
assert_is_op
(
"concat"
);
auto
*
fuse_qkv_split_grad_out_node
=
pattern
->
NewNode
(
fuse_qkv_split_grad_out_repr
())
->
assert_is_op_output
(
"concat"
);
qk_
matmul_grad_x_grad_node
->
assert_is_op_input
(
"concat"
);
// q grad
qk_
scale_grad_out_node
->
assert_is_op_input
(
"concat"
);
// q grad
qk_matmul_grad_w_grad_node
->
assert_is_op_input
(
"concat"
);
// k grad
qkv_matmul_grad_w_grad_node
->
assert_is_op_input
(
"concat"
);
// v grad
fuse_qkv_split_grad_node
->
LinksFrom
({
qk_
matmul_grad_x_grad
_node
,
->
LinksFrom
({
qk_
scale_grad_out
_node
,
qk_matmul_grad_w_grad_node
,
qkv_matmul_grad_w_grad_node
})
.
LinksTo
({
fuse_qkv_split_grad_out_node
});
...
...
@@ -894,7 +894,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
fused_attention_op_desc
.
SetAttr
(
"transpose_qkv_wb"
,
true
);
std
::
vector
<
int
>
shape
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
fuse_qkv_reshape_op_node
->
Op
()
->
GetAttr
(
"shape"
));
fused_attention_op_desc
.
SetAttr
(
"num_heads"
,
shape
[
2
]);
fused_attention_op_desc
.
SetAttr
(
"num_heads"
,
shape
[
2
]
/
3
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_matmul_out_node
,
fuse_qkv_matmul_out
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_ele_add_bias_node
,
...
...
@@ -1337,7 +1337,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
std
::
vector
<
int
>
shape
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
fuse_qkv_reshape_grad_op_node
->
Op
()
->
GetAttr
(
"shape"
));
fused_attention_grad_op_desc
.
SetAttr
(
"num_heads"
,
shape
[
2
]);
fused_attention_grad_op_desc
.
SetAttr
(
"num_heads"
,
shape
[
2
]
/
3
);
fused_attention_grad_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_attention_grad_op_desc
.
SetAttr
(
"transpose_qkv_wb"
,
true
);
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
浏览文件 @
fcec564c
...
...
@@ -53,7 +53,7 @@ class MultiHeadAttention(paddle.nn.Layer):
self
.
qkv_proj
=
paddle
.
nn
.
Linear
(
embed_dim
,
3
*
embed_dim
)
self
.
out_proj
=
paddle
.
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
dropout
=
paddle
.
nn
.
Dropout
(
0.1
,
mode
=
"upscale_in_train"
)
self
.
dropout
=
paddle
.
nn
.
Dropout
(
1e-10
,
mode
=
"upscale_in_train"
)
def
forward
(
self
,
x
,
attn_mask
=
None
):
residual
=
x
...
...
@@ -64,13 +64,13 @@ class MultiHeadAttention(paddle.nn.Layer):
# compute qkv
qkv
=
self
.
qkv_proj
(
x
)
qkv
=
paddle
.
reshape
(
qkv
,
[
0
,
0
,
self
.
num_heads
,
3
*
self
.
head_dim
])
qkv
=
paddle
.
reshape
(
qkv
,
[
0
,
0
,
3
*
self
.
num_heads
,
self
.
head_dim
])
qkv
=
paddle
.
transpose
(
qkv
,
[
0
,
2
,
1
,
3
])
q
,
k
,
v
=
paddle
.
split
(
qkv
,
num_or_sections
=
3
,
axis
=
-
1
)
q
,
k
,
v
=
paddle
.
split
(
qkv
,
num_or_sections
=
3
,
axis
=
1
)
# compute core attention
q
=
paddle
.
scale
(
q
,
scale
=
self
.
head_dim
**-
0.5
)
product
=
paddle
.
matmul
(
x
=
q
,
y
=
k
,
transpose_y
=
True
)
product
=
paddle
.
scale
(
product
,
scale
=
self
.
head_dim
**-
0.5
)
if
attn_mask
is
not
None
:
product
=
product
+
attn_mask
weights
=
F
.
softmax
(
product
)
...
...
@@ -104,21 +104,28 @@ class TestFusedAttentionPass(unittest.TestCase):
self
.
pre_ln
=
True
self
.
attn_dropout
=
True
self
.
add_mask
=
True
self
.
x_data
=
None
self
.
mask_data
=
None
def
test_pass
(
self
):
def
get_rst
(
self
,
use_pass
=
False
):
batch_size
=
2
seq_len
=
1024
hidden_size
=
768
num_heads
=
12
x_data
=
np
.
random
.
rand
(
batch_size
,
seq_len
,
seq_len
).
astype
(
'float32'
)
mask_data
=
np
.
random
.
rand
(
batch_size
,
num_heads
,
seq_len
,
seq_len
).
astype
(
'float32'
)
np
.
random
.
seed
(
1234
)
if
self
.
x_data
is
None
:
self
.
x_data
=
np
.
random
.
rand
(
batch_size
,
seq_len
,
seq_len
).
astype
(
'float32'
)
self
.
mask_data
=
np
.
random
.
rand
(
batch_size
,
num_heads
,
seq_len
,
seq_len
).
astype
(
'float32'
)
main_prog
=
paddle
.
static
.
Program
()
main_prog
.
random_seed
=
1234
startup_prog
=
paddle
.
static
.
Program
()
startup_prog
.
random_seed
=
1234
with
paddle
.
static
.
program_guard
(
main_prog
,
startup_prog
):
data
=
paddle
.
static
.
data
(
...
...
@@ -150,29 +157,36 @@ class TestFusedAttentionPass(unittest.TestCase):
sgd_optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
sgd_optimizer
.
minimize
(
loss
)
pass_manager
=
PassManager
([
new_pass
(
"fused_attention"
)])
pass_manager
.
apply
([
main_prog
],
[
startup_prog
])
ops
=
main_prog
.
global_block
().
ops
assert
ops
[
2
].
type
==
'fused_attention'
assert
ops
[
3
].
type
==
'reduce_mean'
assert
ops
[
5
].
type
==
'reduce_mean_grad'
assert
ops
[
6
].
type
==
'fused_attention_grad'
# two ops for linear, one op for reduce mean
# one fill constant
# one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer
assert
ops
[
9
].
type
==
'sgd'
if
use_pass
:
pass_manager
=
PassManager
([
new_pass
(
"fused_attention"
)])
pass_manager
.
apply
([
main_prog
],
[
startup_prog
])
ops
=
main_prog
.
global_block
().
ops
assert
ops
[
2
].
type
==
'fused_attention'
assert
ops
[
3
].
type
==
'reduce_mean'
assert
ops
[
5
].
type
==
'reduce_mean_grad'
assert
ops
[
6
].
type
==
'fused_attention_grad'
# two ops for linear, one op for reduce mean
# one fill constant
# one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer
assert
ops
[
9
].
type
==
'sgd'
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
startup_prog
)
rst
=
exe
.
run
(
main_prog
,
feed
=
{
'x'
:
x_data
,
'attn_mask'
:
mask_data
},
fetch_list
=
[
loss
],
)
for
i
in
range
(
2
):
rst
=
exe
.
run
(
main_prog
,
feed
=
{
'x'
:
self
.
x_data
,
'attn_mask'
:
self
.
mask_data
},
fetch_list
=
[
loss
],
)
return
rst
def
test_pass
(
self
):
fused_rst
=
self
.
get_rst
(
use_pass
=
True
)
non_fused_rst
=
self
.
get_rst
()
assert
np
.
allclose
(
fused_rst
,
non_fused_rst
)
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
0
)
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录