BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit 7e8ef328, authored on Feb 03, 2023 by Yuang Liu, committed via GitHub on Feb 03, 2023.

Fused attention pass backward op replace. (#50186)

Parent: f2ec69b4
Showing 4 changed files, with 428 additions and 14 deletions (+428 -14):

paddle/fluid/framework/ir/fused_attention_pass.cc  (+381 -10)
paddle/fluid/framework/ir/fused_attention_pass.h   (+36 -2)
paddle/fluid/operators/fused/fused_attention_op.cc (+1 -1)
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py (+10 -1)
paddle/fluid/framework/ir/fused_attention_pass.cc

This diff (+381 -10) is collapsed in the page view; it carries the bulk of the change.
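Per the commit title, and per the updated unit test below (which now expects a fused_attention_grad op in the rewritten program), the collapsed file makes the pass replace the matched attention backward subgraph with a single fused gradient op. As a rough orientation only (toy types and hypothetical names, not Paddle's actual ir::Graph pass API), the shape of such a rewrite is:

#include <cstddef>
#include <string>
#include <vector>

// Toy stand-ins; the real pass manipulates Paddle's ir::Graph / ir::Node.
struct ToyOp {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Replace ops[begin, end) -- the matched backward subgraph -- with one
// fused op. For brevity this toy version unions all inputs/outputs; a
// real pass keeps only the subgraph's boundary variables and erases the
// internal nodes.
std::vector<ToyOp> ReplaceWithFusedGrad(const std::vector<ToyOp>& ops,
                                        std::size_t begin, std::size_t end) {
  std::vector<ToyOp> rewritten(ops.begin(), ops.begin() + begin);
  ToyOp fused{"fused_attention_grad", {}, {}};
  for (std::size_t i = begin; i < end; ++i) {
    fused.inputs.insert(fused.inputs.end(), ops[i].inputs.begin(),
                        ops[i].inputs.end());
    fused.outputs.insert(fused.outputs.end(), ops[i].outputs.begin(),
                         ops[i].outputs.end());
  }
  rewritten.push_back(fused);
  rewritten.insert(rewritten.end(), ops.begin() + end, ops.end());
  return rewritten;
}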
paddle/fluid/framework/ir/fused_attention_pass.h

@@ -16,6 +16,7 @@
 #include <memory>
 #include <string>
+#include <unordered_map>

 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"

@@ -252,6 +253,31 @@ struct FusedAttentionGradPattern : public PatternBase {
 }  // namespace patterns
+class FusedAttentionPassCache {
+ public:
+  ir::Node* GetNodeFromCache(const std::string name) {
+    if (var_name_to_ir_node_cache_.count(name)) {
+      return var_name_to_ir_node_cache_.find(name)->second;
+    }
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "The key (%d) of FusedAttentionCache does not exist.", name));
+  }
+
+  void InsertIntoCache(const std::string name, ir::Node* node) {
+    if (!var_name_to_ir_node_cache_.count(name)) {
+      var_name_to_ir_node_cache_.insert({name, node});
+    } else {
+      PADDLE_THROW(platform::errors::AlreadyExists(
+          "The key (%d) of FusedAttentionCache already exist.", name));
+    }
+  }
+
+  void ResetCache() { var_name_to_ir_node_cache_.clear(); }
+
+ private:
+  std::unordered_map<std::string, ir::Node*> var_name_to_ir_node_cache_;
+};
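The cache added above maps variable names to IR nodes, so nodes recorded while rewriting the forward graph can be looked up again while rewriting the backward graph; it throws on a missing key and on a duplicate insert. A minimal standalone sketch of that contract (toy Node type; std exceptions standing in for PADDLE_THROW):

#include <stdexcept>
#include <string>
#include <unordered_map>

struct Node {};  // stand-in for ir::Node

class MiniCache {
 public:
  Node* Get(const std::string& name) {
    auto it = cache_.find(name);
    if (it == cache_.end())
      throw std::invalid_argument("key does not exist: " + name);
    return it->second;
  }
  void Insert(const std::string& name, Node* node) {
    // unordered_map::insert returns {iterator, bool}; false means taken.
    if (!cache_.insert({name, node}).second)
      throw std::logic_error("key already exists: " + name);
  }
  void Reset() { cache_.clear(); }

 private:
  std::unordered_map<std::string, Node*> cache_;
};

int main() {
  MiniCache cache;
  Node mask_node;
  cache.Insert("attn_0_dropout_mask", &mask_node);  // hypothetical key
  Node* found = cache.Get("attn_0_dropout_mask");   // same pointer back
  cache.Reset();  // cleared between matched subgraphs
  return found == &mask_node ? 0 : 1;
}

The header diff continues with the pass class that consumes this cache: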
 class FusedAttentionsPass : public FusePassBase {
  public:
   virtual ~FusedAttentionsPass() {}

@@ -273,9 +299,17 @@ class FusedAttentionsPass : public FusePassBase {
   // If true, the function name will have an abbreviation part.
   // If false, the function name won't contain an abbreviation for it.
-  ir::Graph* PreMaskDropResFwd(Graph* graph) const;
+  ir::Graph* PreMaskDropResFwd(Graph* graph,
+                               FusedAttentionPassCache* cache) const;

-  ir::Graph* PreMaskDropResBwd(Graph* graph) const;
+  ir::Graph* PreMaskDropResBwd(Graph* graph,
+                               FusedAttentionPassCache* cache) const;
+
+  const std::string GenerateCacheKey(const std::string anchor,
+                                     const std::string var_name,
+                                     int block_id) const {
+    return anchor + "_" + std::to_string(block_id) + "_" + var_name;
+  }
 };

 }  // namespace ir
...
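GenerateCacheKey composes keys as <anchor>_<block_id>_<var_name>, so the forward and backward handlers can independently rebuild the same key for a given variable. A quick check of that composition (free function mirroring the member above; the example values are made up):

#include <cassert>
#include <string>

// Mirrors FusedAttentionsPass::GenerateCacheKey.
std::string GenerateCacheKey(const std::string& anchor,
                             const std::string& var_name, int block_id) {
  return anchor + "_" + std::to_string(block_id) + "_" + var_name;
}

int main() {
  assert(GenerateCacheKey("attn", "dropout_mask", 0) ==
         "attn_0_dropout_mask");
  return 0;
}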
paddle/fluid/operators/fused/fused_attention_op.cc

@@ -375,7 +375,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("BiasDropoutResidualOut",
               "Result of residual + dropout(src + bias).")
         .AsIntermediate();
-    AddOutput("CacheKVOut", "The udpated cache KV.");
+    AddOutput("CacheKVOut", "The udpated cache KV.").AsDispensable();
     AddOutput("Y", "Result after attention.");
     AddAttr<int>("num_heads", "The number head for multi_head_attention.")
...
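The one-line change above marks CacheKVOut with .AsDispensable(), which in Paddle's op maker declares the output optional, so it need not be present when the op runs without a KV cache. A toy sketch of that builder-style flag (hypothetical types, not Paddle's VariableBuilder):

#include <string>
#include <unordered_map>

// Toy op maker: outputs default to required; AsDispensable() makes the
// most recently added output optional.
class ToyOpMaker {
 public:
  ToyOpMaker& AddOutput(const std::string& name, const std::string& doc) {
    (void)doc;  // doc string unused in this toy
    required_[name] = true;
    last_added_ = name;
    return *this;
  }
  ToyOpMaker& AsDispensable() {
    required_[last_added_] = false;
    return *this;
  }
  bool IsRequired(const std::string& name) const { return required_.at(name); }

 private:
  std::unordered_map<std::string, bool> required_;
  std::string last_added_;
};

int main() {
  ToyOpMaker maker;
  maker.AddOutput("CacheKVOut", "The updated cache KV.").AsDispensable();
  return maker.IsRequired("CacheKVOut") ? 1 : 0;  // 0: now optional
}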
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py

@@ -157,11 +157,20 @@ class TestFusedAttentionPass(unittest.TestCase):
         assert ops[2].type == 'fused_attention'
         assert ops[3].type == 'reduce_mean'
         assert ops[5].type == 'reduce_mean_grad'
+        assert ops[6].type == 'fused_attention_grad'
         # two ops for linear, one op for reduce mean
         # one fill constant
         # one op for reduce mean grad, two ops for linear bwd
         # the eighth op should be the optimizer
-        assert ops[8].type == 'sgd'
+        assert ops[9].type == 'sgd'
+
+        exe = paddle.static.Executor()
+        exe.run(startup_prog)
+        rst = exe.run(
+            main_prog,
+            feed={'x': x_data, 'attn_mask': mask_data},
+            fetch_list=[loss],
+        )


 if __name__ == "__main__":

...