Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1fbd4440
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1fbd4440
编写于
6月 08, 2022
作者:
W
Wangzheee
提交者:
GitHub
6月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle-Inference]support matmulv2 in multihead (#43269)
* support matmulv2 in multihead
上级
e1a34bc4
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
42 addition
and
22 deletion
+42
-22
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
+42
-22
未找到文件。
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
浏览文件 @
1fbd4440
...
...
@@ -235,16 +235,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
}
PDNode
*
TrtMultiHeadMatmulPattern
::
operator
()()
{
std
::
unordered_set
<
std
::
string
>
mul_ops
{
"mul"
,
"matmul_v2"
};
std
::
unordered_set
<
std
::
string
>
matmul_ops
{
"matmul"
,
"matmul_v2"
};
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
input0
->
assert_is_op
_input
(
"mul"
);
input0
->
assert_is_op
s_input
(
mul_ops
);
// First path with scale
auto
*
mul0
=
pattern
->
NewNode
(
mul0_repr
())
->
assert_is_op
(
"mul"
);
auto
*
mul0
=
pattern
->
NewNode
(
mul0_repr
())
->
assert_is_op
s
(
mul_ops
);
auto
*
mul0_w_var
=
pattern
->
NewNode
(
mul0_w_repr
())
->
AsInput
()
->
assert_is_op
_input
(
"mul"
,
"Y"
);
->
assert_is_op
s_input
(
mul_ops
,
"Y"
);
auto
*
mul0_out_var
=
pattern
->
NewNode
(
mul0_out_repr
())
->
assert_is_op
_output
(
"mul"
);
pattern
->
NewNode
(
mul0_out_repr
())
->
assert_is_op
s_output
(
mul_ops
);
decltype
(
mul0
)
eltadd0
;
decltype
(
mul0
)
eltadd0_b_var
;
...
...
@@ -277,11 +279,12 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
auto
*
scale
=
pattern
->
NewNode
(
scale_repr
())
->
assert_is_op
(
"scale"
);
auto
*
scale_out_var
=
pattern
->
NewNode
(
scale_out_repr
())
->
assert_is_op_output
(
"scale"
);
scale_out_var
->
AsIntermediate
()
->
assert_is_op
_input
(
"matmul"
);
scale_out_var
->
AsIntermediate
()
->
assert_is_op
s_input
(
matmul_ops
);
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_ops
(
matmul_ops
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op
_output
(
"matmul"
);
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op
s_output
(
matmul_ops
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
...
...
@@ -297,12 +300,12 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
);
softmax_qk_out_var
->
AsIntermediate
()
->
assert_is_op
_input
(
"matmul"
);
softmax_qk_out_var
->
AsIntermediate
()
->
assert_is_op
s_input
(
matmul_ops
);
auto
*
matmul_qkv
=
pattern
->
NewNode
(
matmul_qkv_repr
())
->
assert_is_op
(
"matmul"
);
pattern
->
NewNode
(
matmul_qkv_repr
())
->
assert_is_op
s
(
matmul_ops
);
auto
*
matmul_qkv_out_var
=
pattern
->
NewNode
(
matmul_qkv_out_repr
())
->
assert_is_op
_output
(
"matmul"
);
pattern
->
NewNode
(
matmul_qkv_out_repr
())
->
assert_is_op
s_output
(
matmul_ops
);
matmul_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_qkv
=
...
...
@@ -315,15 +318,15 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
pattern
->
NewNode
(
reshape2_qkv_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_qkv_out_var
=
pattern
->
NewNode
(
reshape2_qkv_out_repr
())
->
assert_is_op_output
(
"reshape2"
);
reshape2_qkv_out_var
->
assert_is_op
_input
(
"mul"
);
reshape2_qkv_out_var
->
assert_is_op
s_input
(
mul_ops
);
// Second path to matmul
auto
*
mul1
=
pattern
->
NewNode
(
mul1_repr
())
->
assert_is_op
(
"mul"
);
auto
*
mul1
=
pattern
->
NewNode
(
mul1_repr
())
->
assert_is_op
s
(
mul_ops
);
auto
*
mul1_w_var
=
pattern
->
NewNode
(
mul1_w_repr
())
->
AsInput
()
->
assert_is_op
_input
(
"mul"
,
"Y"
);
->
assert_is_op
s_input
(
mul_ops
,
"Y"
);
auto
*
mul1_out_var
=
pattern
->
NewNode
(
mul1_out_repr
())
->
assert_is_op
_output
(
"mul"
);
pattern
->
NewNode
(
mul1_out_repr
())
->
assert_is_op
s_output
(
mul_ops
);
decltype
(
mul1
)
eltadd1
;
decltype
(
mul1
)
eltadd1_b_var
;
...
...
@@ -350,16 +353,16 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
pattern
->
NewNode
(
transpose2_1_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_1_out_var
=
pattern
->
NewNode
(
transpose2_1_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
transpose2_1_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
);
// link to matmul qk
transpose2_1_out_var
->
AsIntermediate
()
->
assert_is_op
s
_input
(
matmul_ops
);
// link to matmul qk
// Third path to matmul
auto
*
mul2
=
pattern
->
NewNode
(
mul2_repr
())
->
assert_is_op
(
"mul"
);
auto
*
mul2
=
pattern
->
NewNode
(
mul2_repr
())
->
assert_is_op
s
(
mul_ops
);
auto
*
mul2_w_var
=
pattern
->
NewNode
(
mul2_w_repr
())
->
AsInput
()
->
assert_is_op
_input
(
"mul"
,
"Y"
);
->
assert_is_op
s_input
(
mul_ops
,
"Y"
);
auto
*
mul2_out_var
=
pattern
->
NewNode
(
mul2_out_repr
())
->
assert_is_op
_output
(
"mul"
);
pattern
->
NewNode
(
mul2_out_repr
())
->
assert_is_op
s_output
(
mul_ops
);
decltype
(
mul2
)
eltadd2
;
decltype
(
mul2
)
eltadd2_b_var
;
...
...
@@ -386,8 +389,8 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
pattern
->
NewNode
(
transpose2_2_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_2_out_var
=
pattern
->
NewNode
(
transpose2_2_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
transpose2_2_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
);
// link to matmul qkv
transpose2_2_out_var
->
AsIntermediate
()
->
assert_is_op
s
_input
(
matmul_ops
);
// link to matmul qkv
// Q path
mul0
->
LinksFrom
({
input0
,
mul0_w_var
}).
LinksTo
({
mul0_out_var
});
...
...
@@ -734,6 +737,23 @@ TrtMultiHeadMatmulV2FusePass::TrtMultiHeadMatmulV2FusePass() {
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsType
<
bool
>
()
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"softmax"
))
.
AddInput
(
"X"
)
.
IsTensor
()
...
...
@@ -866,7 +886,7 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
auto
*
mul0_op_desc
=
mul0
->
Op
();
// all mul op has same input.
if
(
mul
tihead_op_desc
.
HasAttr
(
"Input_scale"
))
{
if
(
mul
0_op_desc
->
HasAttr
(
"Input_scale"
))
{
multihead_op_desc
.
SetAttr
(
"Input_scale"
,
mul0_op_desc
->
GetAttr
(
"Input_scale"
));
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录