PaddlePaddle / Paddle

Commit 178b2440 (unverified)
Authored by Wangzheee on Jun 23, 2022; committed via GitHub on Jun 23, 2022.
general_prelayernorm_transformer (#43748)
Parent: dbf92d49
Showing 8 changed files with 511 additions and 199 deletions (+511 −199).
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc   +43   −28
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h    +3    −1
paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc                +18   −12
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc           +158  −19
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h            +23   −0
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc                +224  −118
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc                  +8    −6
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc     +34   −15
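The passes touched here extend Paddle's TensorRT transformer fusions from the post-LayerNorm layout to the pre-LayerNorm ("preln") layout. The snippet below is a minimal illustrative sketch, not Paddle code, of why the fused preln ops in this diff (preln_skip_layernorm, fused_preln_embedding_eltwise_layernorm) presumably carry two outputs Out_0/Out_1 while their post-LN counterparts carry one: in a pre-LN block the un-normalized residual sum has to be forwarded to the next layer alongside the normalized tensor. Function and variable names are assumptions.

#include <vector>

// Illustrative sketch only (not Paddle's kernel): a pre-LayerNorm skip + LN step.
// out_0 keeps the raw residual sum for the next block's skip connection,
// out_1 is the (here omitted) normalized tensor fed into the following sublayer.
void preln_skip_layernorm_sketch(const std::vector<float>& x,
                                 const std::vector<float>& y,
                                 std::vector<float>* out_0,   // x + y
                                 std::vector<float>* out_1) { // LayerNorm(x + y)
  out_0->resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) (*out_0)[i] = x[i] + y[i];
  *out_1 = *out_0;  // LayerNorm itself omitted; only the data flow matters here.
}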
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc

@@ -31,7 +31,8 @@ namespace framework {
 namespace ir {
 namespace patterns {
 static PDNode* create_emb_vars(PDPattern* pattern,
                                const std::string& name,
                                const std::string& arg,
                                bool is_persist = false) {
   std::unordered_set<std::string> embedding_ops{"lookup_table",

@@ -41,7 +42,8 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
   if (is_persist) return node->assert_is_persistable_var();
   return node;
 }
 static PDNode* create_emb_out_vars(PDPattern* pattern,
                                    const std::string& name,
                                    const std::string& arg) {
   std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                 "lookup_table_v2"};

@@ -62,6 +64,9 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
   create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
   std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                 "lookup_table_v2"};
+  auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
+  auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed");
   auto* lookup_table1 =
       pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table2 =

@@ -74,8 +79,10 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
       pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
   auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
                               ->assert_is_op_output("elementwise_add");
+  feed1->LinksTo({lookup_table1_x});
   lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
       .LinksTo({lookup_table1_out});
+  feed2->LinksTo({lookup_table2_x});
   lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
       .LinksTo({lookup_table2_out});
   eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})

@@ -88,6 +95,8 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
   create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
   std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                 "lookup_table_v2"};
+  auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
   auto* lookup_table1 =
       pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table1_out =

@@ -101,6 +110,7 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
           ->assert_is_op_output("elementwise_add");
   lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
       .LinksTo({lookup_table1_out});
+  feed1->LinksTo({lookup_table1_x});
   eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
       .LinksTo({eltwise_add_out});
 }

@@ -161,10 +171,10 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         lookup_table1_out, lookup_table1_out, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         lookup_table2_out, lookup_table2_out, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
     if (!IsCompat(subgraph, graph)) {

@@ -179,8 +189,12 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
     start_pattern_out_node.push_back(eltwise_add_out);
     std::unordered_set<Node*> rm_nodes;
-    rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out,
-                     lookup_table2_out, eltwise_add, eltwise_add_out});
+    rm_nodes.insert({lookup_table1,
+                     lookup_table2,
+                     lookup_table1_out,
+                     lookup_table2_out,
+                     eltwise_add,
+                     eltwise_add_out});
     start_pattern_remove_nodes.push_back(rm_nodes);
   };
   gpd(graph, handler);

@@ -200,8 +214,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         lookup_table1_out, lookup_table1_out, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);

@@ -236,19 +250,19 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
   auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         eltwise_add_out, eltwise_add_out, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         layer_norm_out, layer_norm_out, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         layer_norm_bias, layer_norm_bias, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         layer_norm_scale, layer_norm_scale, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         layer_norm_mean, layer_norm_mean, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         layer_norm_variance, layer_norm_variance, skip_layernorm_pattern);
     if (!IsCompat(subgraph, graph)) {
       LOG(WARNING) << "Pass(PrelnSkipLayerNorm) in op compat failed.";
       return;

@@ -313,7 +327,7 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
       embs.push_back(inner_pattern_ins[js[iter]].second->Name());
     }
-    OpDesc new_op_desc;
+    OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
     new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm");
     new_op_desc.SetInput("Ids", ids);
     new_op_desc.SetInput("Embs", embs);

@@ -433,16 +447,17 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
   bool use_varseqlen = Get<bool>("use_varseqlen");
   bool with_interleaved = Get<bool>("with_interleaved");
   bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
-  if (!(enable_int8 && use_varseqlen && with_interleaved &&
-        with_dynamic_shape)) {
-    VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
-               "enable_int8, "
+  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
+  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
+  if (!(enable_int8 && use_varseqlen && with_interleaved && pos_id != "" &&
+        mask_id != "" && with_dynamic_shape)) {
+    VLOG(3) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
+               "enable_int8, set pos_id, set mask_id, "
                "use_varseqlen, with_interleaved, with_dynamic_shape. Stop this "
                "pass, "
                "please reconfig.";
     return;
   }
   int fusion_count =
       PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
   if (fusion_count > 0) {
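For readers unfamiliar with how these fuse passes replace a matched subgraph, the fragment below is a hedged sketch of the Paddle IR idiom that the OpDesc new_op_desc(...Block()) change above participates in: the new op desc is created against the surrounding block, materialized as a graph node, and linked to the surviving variable nodes before the matched nodes are erased. The variable names (block, ids_node, out_node, marked_nodes) are illustrative assumptions, not copied from this file.

// Sketch under assumptions; mirrors the idiom the diff itself uses
// (OpDesc(BlockDesc*), CreateOpNode, IR_NODE_LINK_TO, GraphSafeRemoveNodes).
OpDesc new_op_desc(block);  // attach the desc to the enclosing BlockDesc
new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", {ids_node->Name()});
new_op_desc.SetOutput("Out_0", {out_node->Name()});

auto* fused_node = graph->CreateOpNode(&new_op_desc);  // op node in the graph
IR_NODE_LINK_TO(ids_node, fused_node);                 // rewire var -> op edges
IR_NODE_LINK_TO(fused_node, out_node);
GraphSafeRemoveNodes(graph, marked_nodes);             // drop the old subgraph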
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h

@@ -51,7 +51,8 @@ struct PrelnEmbedding2Eltwise1Pattern : public PatternBase {
       : PatternBase(pattern, name_scope, "Prelnembedding2_eltwise1") {}

   void operator()();

+  PATTERN_DECL_NODE(feed1);
+  PATTERN_DECL_NODE(feed2);
   PATTERN_DECL_NODE(lookup_table1_x);
   PATTERN_DECL_NODE(lookup_table2_x);
   PATTERN_DECL_NODE(lookup_table1_w);

@@ -81,6 +82,7 @@ struct PrelnEmbedding1Eltwise1Pattern : public PatternBase {
                                 const std::string& name_scope)
       : PatternBase(pattern, name_scope, "Prelnembedding1_eltwise1") {}

   void operator()();

+  PATTERN_DECL_NODE(feed1);
   PATTERN_DECL_NODE(lookup_table1_x);
   PATTERN_DECL_NODE(lookup_table1_w);
   PATTERN_DECL_NODE(lookup_table1);
paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc

@@ -112,15 +112,21 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
   bool use_varseqlen = Get<bool>("use_varseqlen");
   bool with_interleaved = Get<bool>("with_interleaved");
   bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
+  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
+  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
   if (!(enable_int8 && use_varseqlen && with_interleaved &&
-        with_dynamic_shape)) {
-    VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
-               "use_varseqlen, "
-               "with_interleaved, with_dynamic_shape. Stop this pass, please "
-               "reconfig. ";
+        graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
+        graph->Has(framework::ir::kMultiheadMatmulPass) && pos_id != "" &&
+        mask_id != "" && with_dynamic_shape)) {
+    VLOG(3) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
+               "with_interleaved"
+               "use_varseqlen, preln_embedding_eltwise_layernorm_fuse_pass, "
+               "trt_multihead_matmul_fuse_pass"
+               "set pos_id, set mask_id, with_dynamic_shape. Stop this pass, "
+               "please "
+               "reconfig.";
     return;
   }
   int found_subgraph_count = 0;
   GraphPatternDetector gpd;

@@ -155,17 +161,17 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         layer_norm_scale, layer_norm_scale, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         layer_norm_variance, layer_norm_variance, fused_pattern);
     std::unordered_set<const Node*> del_node_set;

     // Create an PrelnSkipLayerNorm op node
-    OpDesc new_desc;
+    OpDesc new_desc(elementwise->Op()->Block());
     new_desc.SetType("preln_skip_layernorm");

     // inputs

@@ -209,8 +215,8 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
     found_subgraph_count++;
   };
   gpd(graph, handler);
   AddStatis(found_subgraph_count);
 }
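The tightened precondition above means this pass now only fires after the preln embedding+layernorm fuse pass and the TRT multihead-matmul fuse pass have already marked the graph, and only when the TensorRT transformer pos_id/mask_id names are configured. The snippet below is a sketch, modeled on how the analogous post-LN passes behave, of how an upstream pass would typically set such a marker; it is an assumption about code outside this diff, not a quote from it.

// Sketch (assumption): upstream fuse passes record a graph-level flag that
// downstream passes query with graph->Has(...), as in the condition above.
if (fusion_count > 0) {
  graph->Set(framework::ir::kPrelnEmbEltwiseLayernormPass, new bool(true));
  AddStatis(fusion_count);
}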
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc

@@ -35,6 +35,25 @@ void EmbEltwiseLayernorm::operator()() {
   emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out});
 }

+void PrelnEmbEltwiseLayernorm::operator()() {
+  // Create nodes for fused_preln_embedding_eltwise_layernorm.
+  auto* preln_emb_elt_layernorm_op =
+      pattern->NewNode(preln_emb_elt_layernorm_op_repr())
+          ->assert_is_op("fused_preln_embedding_eltwise_layernorm");
+  auto* preln_emb_elt_layernorm_out_0 =
+      pattern->NewNode(preln_emb_elt_layernorm_out_0_repr())
+          ->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
+                                "Out_0");
+  auto* preln_emb_elt_layernorm_out_1 =
+      pattern->NewNode(preln_emb_elt_layernorm_out_1_repr())
+          ->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
+                                "Out_1");
+
+  // Add links for fused_preln_embedding_eltwise_layernorm op.
+  preln_emb_elt_layernorm_op->LinksTo(
+      {preln_emb_elt_layernorm_out_0, preln_emb_elt_layernorm_out_1});
+}
+
 void SkipLayernorm::operator()() {
   // Create nodes for skip_layernorm.
   auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())

@@ -51,6 +70,30 @@ void SkipLayernorm::operator()() {
       .LinksTo({skip_layernorm_out});
 }

+void PrelnSkipLayernorm::operator()() {
+  // Create nodes for preln_skip_layernorm.
+  auto* preln_skip_layernorm_x =
+      pattern->NewNode(preln_skip_layernorm_x_repr())
+          ->assert_is_op_input("preln_skip_layernorm", "X");
+  auto* preln_skip_layernorm_y =
+      pattern->NewNode(preln_skip_layernorm_y_repr())
+          ->assert_is_op_input("preln_skip_layernorm", "Y");
+  auto* preln_skip_layernorm_op =
+      pattern->NewNode(preln_skip_layernorm_op_repr())
+          ->assert_is_op("preln_skip_layernorm");
+  auto* preln_skip_layernorm_out_0 =
+      pattern->NewNode(preln_skip_layernorm_out_0_repr())
+          ->assert_is_op_output("preln_skip_layernorm", "Out_0");
+  auto* preln_skip_layernorm_out_1 =
+      pattern->NewNode(preln_skip_layernorm_out_1_repr())
+          ->assert_is_op_output("preln_skip_layernorm", "Out_1");
+
+  // Add links for preln_skip_layernorm op.
+  preln_skip_layernorm_op
+      ->LinksFrom({preln_skip_layernorm_x, preln_skip_layernorm_y})
+      .LinksTo({preln_skip_layernorm_out_0, preln_skip_layernorm_out_1});
+}
+
 void MultiheadMatmul::operator()() {
   // Create nodes for multihead_matmul.
   auto* multihead_matmul_input =

@@ -96,10 +139,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
   std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
   if (use_varseqlen && pos_id != "" && mask_id != "" &&
-      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+      (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
+       graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
       graph->Has(framework::ir::kMultiheadMatmulPass)) {
     VLOG(3) << "start varseqlen remove_padding_recover_padding_pass";
   } else {
+    VLOG(3) << "remove_padding_recover_padding_pass check failed";
     return;
   }

@@ -131,9 +176,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     remove_padding.SetOutput("Out", {remove_padding_out_name});

     // set out_threshold for int8
-    if (op_node->Op()->HasAttr("out_threshold")) {
+    if (op_node->Op()->HasAttr("Input_scale")) {
       remove_padding.SetAttr("out_threshold",
-                             op_node->Op()->GetAttr("out_threshold"));
+                             op_node->Op()->GetAttr("Input_scale"));
+    } else {
+      VLOG(3) << "remove_padding_op has not out_threshold, because next op has "
+                 "not Input_scale.";
     }

     auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);

@@ -194,6 +242,15 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     if (op_node->Op()->HasAttr("out_threshold")) {
       recover_padding.SetAttr("out_threshold",
                               op_node->Op()->GetAttr("out_threshold"));
+    } else if (op_node->Op()->HasAttr("out_0_threshold")) {
+      recover_padding.SetAttr("out_threshold",
+                              op_node->Op()->GetAttr("out_0_threshold"));
+    } else if (op_node->Op()->HasAttr("out_1_threshold")) {
+      recover_padding.SetAttr("out_threshold",
+                              op_node->Op()->GetAttr("out_1_threshold"));
+    } else {
+      VLOG(3) << "recover_padding_op has not out_threshold, because previous "
+                 "op has not out_*_threshold.";
     }

     auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);

@@ -241,9 +298,11 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
                "fused_embedding_eltwise_layernorm";

     GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op, emb_elt_layernorm_op,
                               fused_embedding_eltwise_layernorm);
     GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out, emb_elt_layernorm_out,
                               fused_embedding_eltwise_layernorm);

     insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out);

@@ -263,12 +322,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
                "multihead_matmul";

     GET_IR_NODE_FROM_SUBGRAPH(
         multihead_matmul_input, multihead_matmul_input, multihead_matmul);
     GET_IR_NODE_FROM_SUBGRAPH(
         multihead_matmul_op, multihead_matmul_op, multihead_matmul);
     GET_IR_NODE_FROM_SUBGRAPH(
         multihead_matmul_out, multihead_matmul_out, multihead_matmul);

     multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape();

@@ -289,14 +348,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
                "skip_layernorm";

     GET_IR_NODE_FROM_SUBGRAPH(
         skip_layernorm_x, skip_layernorm_x, skip_layernorm);
     GET_IR_NODE_FROM_SUBGRAPH(
         skip_layernorm_y, skip_layernorm_y, skip_layernorm);
     GET_IR_NODE_FROM_SUBGRAPH(
         skip_layernorm_op, skip_layernorm_op, skip_layernorm);
     GET_IR_NODE_FROM_SUBGRAPH(
         skip_layernorm_out, skip_layernorm_out, skip_layernorm);

     std::vector<int64_t> skip_layernorm_x_shape =
         skip_layernorm_x->Var()->GetShape();

@@ -417,6 +476,86 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
   };
   gpd4(graph, handler4);

+  GraphPatternDetector gpd5;
+  patterns::PrelnEmbEltwiseLayernorm fused_preln_embedding_eltwise_layernorm(
+      gpd5.mutable_pattern(), "remove_padding_recover_padding_pass");
+  fused_preln_embedding_eltwise_layernorm();
+
+  auto handler5 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
+               "fused_preln_embedding_eltwise_layernorm";
+
+    GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_op,
+                              preln_emb_elt_layernorm_op,
+                              fused_preln_embedding_eltwise_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_0,
+                              preln_emb_elt_layernorm_out_0,
+                              fused_preln_embedding_eltwise_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_1,
+                              preln_emb_elt_layernorm_out_1,
+                              fused_preln_embedding_eltwise_layernorm);
+
+    insert_recover_padding_op(preln_emb_elt_layernorm_op,
+                              preln_emb_elt_layernorm_out_0);
+    insert_recover_padding_op(preln_emb_elt_layernorm_op,
+                              preln_emb_elt_layernorm_out_1);
+    found_subgraph_count++;
+  };
+  gpd5(graph, handler5);
+
+  GraphPatternDetector gpd6;
+  patterns::PrelnSkipLayernorm preln_skip_layernorm(
+      gpd6.mutable_pattern(), "remove_padding_recover_padding_pass");
+  preln_skip_layernorm();
+
+  auto handler6 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
+               "preln_skip_layernorm";
+
+    GET_IR_NODE_FROM_SUBGRAPH(
+        preln_skip_layernorm_x, preln_skip_layernorm_x, preln_skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        preln_skip_layernorm_y, preln_skip_layernorm_y, preln_skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        preln_skip_layernorm_op, preln_skip_layernorm_op, preln_skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_0,
+                              preln_skip_layernorm_out_0,
+                              preln_skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_1,
+                              preln_skip_layernorm_out_1,
+                              preln_skip_layernorm);
+
+    std::vector<int64_t> skip_layernorm_x_shape =
+        preln_skip_layernorm_x->Var()->GetShape();
+    if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
+      check_flag = false;
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+    for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) {
+      if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) {
+        check_flag = false;
+      }
+    }
+    if (!check_flag) {
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+
+    insert_remove_padding_op(preln_skip_layernorm_x, preln_skip_layernorm_op);
+    insert_remove_padding_op(preln_skip_layernorm_y, preln_skip_layernorm_op);
+
+    insert_recover_padding_op(preln_skip_layernorm_op,
+                              preln_skip_layernorm_out_0);
+    insert_recover_padding_op(preln_skip_layernorm_op,
+                              preln_skip_layernorm_out_1);
+    found_subgraph_count++;
+  };
+  gpd6(graph, handler6);
+
   AddStatis(found_subgraph_count);
 }
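When variable sequence length is enabled, remove_padding and recover_padding bracket each fused transformer op: padding tokens are stripped before the op and re-inserted afterwards so that the surrounding graph keeps rectangular shapes. The following is a self-contained sketch of that round trip with invented names and a simple mask convention; it is not the implementation of Paddle's remove_padding or recover_padding operators.

#include <cstdio>
#include <vector>

// Sketch with assumed conventions: mask[i] is 1 for a real token, 0 for padding.
// remove_padding-style compaction drops padded positions; recover_padding-style
// expansion scatters them back, zero-filling the padded slots.
std::vector<float> RemovePaddingSketch(const std::vector<float>& padded,
                                       const std::vector<int>& mask) {
  std::vector<float> compact;
  for (size_t i = 0; i < padded.size(); ++i)
    if (mask[i]) compact.push_back(padded[i]);
  return compact;
}

std::vector<float> RecoverPaddingSketch(const std::vector<float>& compact,
                                        const std::vector<int>& mask) {
  std::vector<float> padded(mask.size(), 0.0f);
  size_t j = 0;
  for (size_t i = 0; i < mask.size(); ++i)
    if (mask[i]) padded[i] = compact[j++];
  return padded;
}

int main() {
  std::vector<float> x = {1, 2, 0, 0, 3, 4, 5, 0};  // two padded sequences of length 4
  std::vector<int> mask = {1, 1, 0, 0, 1, 1, 1, 0};
  auto compact = RemovePaddingSketch(x, mask);       // {1, 2, 3, 4, 5}
  auto back = RecoverPaddingSketch(compact, mask);   // original padded layout restored
  std::printf("%zu real tokens\n", compact.size());
  return 0;
}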
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h

@@ -41,6 +41,16 @@ struct EmbEltwiseLayernorm : public PatternBase {
   PATTERN_DECL_NODE(emb_elt_layernorm_out);
 };

+struct PrelnEmbEltwiseLayernorm : public PatternBase {
+  PrelnEmbEltwiseLayernorm(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "preln_emb_elt_layernorm") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(preln_emb_elt_layernorm_op);
+  PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_0);
+  PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_1);
+};
+
 struct SkipLayernorm : public PatternBase {
   SkipLayernorm(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "skip_layernorm") {}

@@ -53,6 +63,19 @@ struct SkipLayernorm : public PatternBase {
   PATTERN_DECL_NODE(skip_layernorm_out);
 };

+struct PrelnSkipLayernorm : public PatternBase {
+  PrelnSkipLayernorm(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "preln_skip_layernorm") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(preln_skip_layernorm_x);
+  PATTERN_DECL_NODE(preln_skip_layernorm_y);
+  PATTERN_DECL_NODE(preln_skip_layernorm_op);
+  PATTERN_DECL_NODE(preln_skip_layernorm_out_0);
+  PATTERN_DECL_NODE(preln_skip_layernorm_out_1);
+};
+
 struct MultiheadMatmul : public PatternBase {
   MultiheadMatmul(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "multihead_matmul") {}
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
浏览文件 @
178b2440
...
@@ -51,11 +51,20 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
...
@@ -51,11 +51,20 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
multihead_pattern
();
multihead_pattern
();
// Create New OpDesc
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
mul0
,
Node
*
mul1
,
Node
*
mul2
,
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
mul0_out
,
Node
*
mul1_out
,
Node
*
mul2_out
,
Node
*
mul0
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
mul1
,
Node
*
eltadd_qk_b
,
Node
*
reshape2
,
Node
*
mul2
,
Node
*
reshape2_qkv_out
,
Node
*
scale
,
Node
*
mul0_out
,
Node
*
mul1_out
,
Node
*
mul2_out
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
eltadd_qk_b
,
Node
*
reshape2
,
Node
*
reshape2_qkv_out
,
Node
*
scale
,
Node
*
scale_out
)
{
Node
*
scale_out
)
{
auto
scale_attr
=
BOOST_GET_CONST
(
float
,
scale
->
Op
()
->
GetAttr
(
"scale"
));
auto
scale_attr
=
BOOST_GET_CONST
(
float
,
scale
->
Op
()
->
GetAttr
(
"scale"
));
// auto scale_bias = BOOST_GET_CONST(float, scale->Op()->GetAttr("bias"));
// auto scale_bias = BOOST_GET_CONST(float, scale->Op()->GetAttr("bias"));
...
@@ -123,11 +132,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
...
@@ -123,11 +132,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH
(
mul0_out
,
mul0_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul0_out
,
mul0_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul0_w
,
mul0_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul0_w
,
mul0_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
reshape2_0_out
,
reshape2_0_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_0_out
,
transpose2_0_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale
,
scale
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale
,
scale
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_out
,
scale_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_out
,
scale_out
,
multihead_pattern
);
...
@@ -135,21 +144,21 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
...
@@ -135,21 +144,21 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH
(
mul1_out
,
mul1_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul1_out
,
mul1_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul1_w
,
mul1_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul1_w
,
mul1_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1
,
reshape2_1
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1
,
reshape2_1
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1_out
,
reshape2_1_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
reshape2_1_out
,
reshape2_1_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1
,
transpose2_1
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1
,
transpose2_1
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1_out
,
transpose2_1_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_1_out
,
transpose2_1_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2
,
mul2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2
,
mul2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2_out
,
mul2_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2_out
,
mul2_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2_w
,
mul2_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2_w
,
mul2_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2
,
reshape2_2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2
,
reshape2_2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2_out
,
reshape2_2_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
reshape2_2_out
,
reshape2_2_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2
,
transpose2_2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2
,
transpose2_2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2_out
,
transpose2_2_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_2_out
,
transpose2_2_out
,
multihead_pattern
);
// nodes need be removed
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
multihead_pattern
);
...
@@ -172,24 +181,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
...
@@ -172,24 +181,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
softmax_qk_out
,
softmax_qk_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
matmul_qkv_out
,
matmul_qkv_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
reshape2_qkv_out
,
reshape2_qkv_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_qkv
,
transpose2_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_qkv_out
,
transpose2_qkv_out
,
multihead_pattern
);
fuse_creater
(
input0
,
mul0
,
mul1
,
mul2
,
mul0_out
,
mul1_out
,
mul2_out
,
fuse_creater
(
input0
,
eltadd0_b
,
eltadd1_b
,
eltadd2_b
,
eltadd_qk_b
,
reshape2_0
,
mul0
,
reshape2_qkv_out
,
scale
,
scale_out
);
mul1
,
mul2
,
mul0_out
,
mul1_out
,
mul2_out
,
eltadd0_b
,
eltadd1_b
,
eltadd2_b
,
eltadd_qk_b
,
reshape2_0
,
reshape2_qkv_out
,
scale
,
scale_out
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
{
eltadd0
,
{
eltadd0
,
...
@@ -777,14 +798,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
...
@@ -777,14 +798,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
multihead_pattern
();
multihead_pattern
();
// Create New OpDesc
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
mul0
,
Node
*
mul1
,
Node
*
mul2
,
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
mul0_out
,
Node
*
mul1_out
,
Node
*
mul2_out
,
Node
*
mul0
,
Node
*
mul0_w
,
Node
*
mul1_w
,
Node
*
mul2_w
,
Node
*
mul1
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
mul2
,
Node
*
eltadd_qk_b
,
Node
*
reshape2
,
Node
*
mul0_out
,
Node
*
reshape2_qkv_out
,
Node
*
scale
,
Node
*
scale_out
,
Node
*
mul1_out
,
Node
*
softmax_qk
,
Node
*
eltadd0
,
Node
*
eltadd1
,
Node
*
mul2_out
,
Node
*
eltadd2
,
Node
*
matmul_qk
,
Node
*
reshape2_qkv
)
{
Node
*
mul0_w
,
Node
*
mul1_w
,
Node
*
mul2_w
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
eltadd_qk_b
,
Node
*
reshape2
,
Node
*
reshape2_qkv_out
,
Node
*
scale
,
Node
*
scale_out
,
Node
*
softmax_qk
,
Node
*
eltadd0
,
Node
*
eltadd1
,
Node
*
eltadd2
,
Node
*
matmul_qk
,
Node
*
reshape2_qkv
)
{
auto
scale_attr
=
BOOST_GET_CONST
(
float
,
scale
->
Op
()
->
GetAttr
(
"scale"
));
auto
scale_attr
=
BOOST_GET_CONST
(
float
,
scale
->
Op
()
->
GetAttr
(
"scale"
));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
...
@@ -842,7 +879,8 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
...
@@ -842,7 +879,8 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
wq_tensor
->
Resize
(
combined_w_dims
);
wq_tensor
->
Resize
(
combined_w_dims
);
auto
*
new_combined_w_data
=
auto
*
new_combined_w_data
=
wq_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
wq_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
memcpy
(
new_combined_w_data
,
tmp_combined_w_data
,
memcpy
(
new_combined_w_data
,
tmp_combined_w_data
,
sizeof
(
float
)
*
wq_tensor
->
numel
());
sizeof
(
float
)
*
wq_tensor
->
numel
());
scope
->
EraseVars
({
mul1_w
->
Name
(),
mul2_w
->
Name
()});
scope
->
EraseVars
({
mul1_w
->
Name
(),
mul2_w
->
Name
()});
...
@@ -854,15 +892,17 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
...
@@ -854,15 +892,17 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
size_t
bias_size
=
bq_tensor
->
numel
();
size_t
bias_size
=
bq_tensor
->
numel
();
memcpy
(
tmp_combined_bias_data
,
bq_data
,
sizeof
(
float
)
*
bias_size
);
memcpy
(
tmp_combined_bias_data
,
bq_data
,
sizeof
(
float
)
*
bias_size
);
memcpy
(
tmp_combined_bias_data
+
bias_size
,
bk_data
,
memcpy
(
sizeof
(
float
)
*
bias_size
);
tmp_combined_bias_data
+
bias_size
,
bk_data
,
sizeof
(
float
)
*
bias_size
);
memcpy
(
tmp_combined_bias_data
+
2
*
bias_size
,
bv_data
,
memcpy
(
tmp_combined_bias_data
+
2
*
bias_size
,
bv_data
,
sizeof
(
float
)
*
bias_size
);
sizeof
(
float
)
*
bias_size
);
bq_tensor
->
Resize
(
combined_bias_dims
);
bq_tensor
->
Resize
(
combined_bias_dims
);
auto
*
new_combined_bias_data
=
auto
*
new_combined_bias_data
=
bq_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
bq_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
memcpy
(
new_combined_bias_data
,
tmp_combined_bias_data
,
memcpy
(
new_combined_bias_data
,
tmp_combined_bias_data
,
sizeof
(
float
)
*
bq_tensor
->
numel
());
sizeof
(
float
)
*
bq_tensor
->
numel
());
scope
->
EraseVars
({
eltadd1_b
->
Name
(),
eltadd2_b
->
Name
()});
scope
->
EraseVars
({
eltadd1_b
->
Name
(),
eltadd2_b
->
Name
()});
...
@@ -944,11 +984,11 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
...
@@ -944,11 +984,11 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH
(
mul0_out
,
mul0_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul0_out
,
mul0_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul0_w
,
mul0_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul0_w
,
mul0_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
reshape2_0_out
,
reshape2_0_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_0_out
,
transpose2_0_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale
,
scale
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale
,
scale
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_out
,
scale_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_out
,
scale_out
,
multihead_pattern
);
...
@@ -956,21 +996,21 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
...
@@ -956,21 +996,21 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH
(
mul1_out
,
mul1_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul1_out
,
mul1_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul1_w
,
mul1_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul1_w
,
mul1_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1
,
reshape2_1
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1
,
reshape2_1
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1_out
,
reshape2_1_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
reshape2_1_out
,
reshape2_1_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1
,
transpose2_1
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1
,
transpose2_1
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1_out
,
transpose2_1_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_1_out
,
transpose2_1_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2
,
mul2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2
,
mul2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2_out
,
mul2_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2_out
,
mul2_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2_w
,
mul2_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
mul2_w
,
mul2_w
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2
,
reshape2_2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2
,
reshape2_2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2_out
,
reshape2_2_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
reshape2_2_out
,
reshape2_2_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2
,
transpose2_2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2
,
transpose2_2
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2_out
,
transpose2_2_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_2_out
,
transpose2_2_out
,
multihead_pattern
);
// nodes need be removed
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
multihead_pattern
);
...
@@ -993,20 +1033,20 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
...
@@ -993,20 +1033,20 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
softmax_qk_out
,
softmax_qk_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
matmul_qkv_out
,
matmul_qkv_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
reshape2_qkv_out
,
reshape2_qkv_out
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_qkv
,
transpose2_qkv
,
multihead_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_pattern
);
transpose2_qkv_out
,
transpose2_qkv_out
,
multihead_pattern
);
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// patterns, we do not support this kind of fusion, this pass will not take
// patterns, we do not support this kind of fusion, this pass will not take
...
@@ -1018,10 +1058,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
...
@@ -1018,10 +1058,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
if
(
is_fc_params_shared
)
{
if
(
is_fc_params_shared
)
{
return
;
return
;
}
}
fuse_creater
(
input0
,
mul0
,
mul1
,
mul2
,
mul0_out
,
mul1_out
,
mul2_out
,
mul0_w
,
fuse_creater
(
input0
,
mul1_w
,
mul2_w
,
eltadd0_b
,
eltadd1_b
,
eltadd2_b
,
eltadd_qk_b
,
mul0
,
reshape2_0
,
reshape2_qkv_out
,
scale
,
scale_out
,
softmax_qk
,
mul1
,
eltadd0
,
eltadd1
,
eltadd2
,
matmul_qk
,
reshape2_qkv
);
mul2
,
mul0_out
,
mul1_out
,
mul2_out
,
mul0_w
,
mul1_w
,
mul2_w
,
eltadd0_b
,
eltadd1_b
,
eltadd2_b
,
eltadd_qk_b
,
reshape2_0
,
reshape2_qkv_out
,
scale
,
scale_out
,
softmax_qk
,
eltadd0
,
eltadd1
,
eltadd2
,
matmul_qk
,
reshape2_qkv
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
eltadd0
,
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
eltadd0
,
eltadd1
,
eltadd1
,
...
@@ -1083,19 +1143,28 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
...
@@ -1083,19 +1143,28 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
int
fusion_count
=
BuildFusionV2
(
graph
,
name_scope_
,
scope
);
int
fusion_count
=
BuildFusionV2
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
if
(
fusion_count
>
0
)
{
bool
use_varseqlen
=
Get
<
bool
>
(
"use_varseqlen"
);
bool
use_varseqlen
=
Get
<
bool
>
(
"use_varseqlen"
);
bool
with_interleaved
=
Get
<
bool
>
(
"with_interleaved"
);
std
::
string
pos_id
=
Get
<
std
::
string
>
(
"tensorrt_transformer_posid"
);
std
::
string
pos_id
=
Get
<
std
::
string
>
(
"tensorrt_transformer_posid"
);
std
::
string
mask_id
=
Get
<
std
::
string
>
(
"tensorrt_transformer_maskid"
);
std
::
string
mask_id
=
Get
<
std
::
string
>
(
"tensorrt_transformer_maskid"
);
if
(
use_varseqlen
&&
pos_id
!=
""
&&
mask_id
!=
""
)
{
if
(
use_varseqlen
&&
pos_id
!=
""
&&
mask_id
!=
""
)
{
if
(
graph
->
Has
(
framework
::
ir
::
kEmbEltwiseLayernormPass
))
{
if
(
graph
->
Has
(
framework
::
ir
::
kEmbEltwiseLayernormPass
)
||
VLOG
(
3
)
<<
"start varseqlen trt_multihead_matmul_fuse_pass_v2"
;
graph
->
Has
(
framework
::
ir
::
kPrelnEmbEltwiseLayernormPass
))
{
if
(
with_interleaved
)
{
VLOG
(
3
)
<<
"start interleaved_format "
"varseqlen_trt_multihead_matmul_fuse_pass_v2"
;
}
else
{
VLOG
(
3
)
<<
"start varseqlen_trt_multihead_matmul_fuse_pass_v2"
;
}
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
"Use transformer'varseqlen need "
platform
::
errors
::
Fatal
(
"Use transformer'varseqlen need "
"embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"
));
"embedding_eltwise_layernorm_fuse_pass or "
"preln_embedding_eltwise_layernorm_fuse_"
"pass. please use no_varseqlen"
));
}
}
}
else
if
(
!
use_varseqlen
&&
pos_id
==
""
&&
mask_id
==
""
)
{
}
else
if
(
!
use_varseqlen
&&
pos_id
==
""
&&
mask_id
==
""
)
{
VLOG
(
3
)
<<
"start no_varseqlen
trt_multihead_matmul_fuse_pass_v2
"
;
VLOG
(
3
)
<<
"start no_varseqlen
_trt_multihead_matmul_fuse_pass
"
;
}
else
{
}
else
{
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Use transformer'varseqlen need config: "
platform
::
errors
::
Fatal
(
"Use transformer'varseqlen need config: "
...
@@ -1251,12 +1320,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
...
@@ -1251,12 +1320,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
multihead_pattern
();
multihead_pattern
();
// Create New OpDesc
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
mul0
,
Node
*
mul1
,
Node
*
mul2
,
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
mul0_out
,
Node
*
mul1_out
,
Node
*
mul2_out
,
Node
*
mul0
,
Node
*
mul0_w
,
Node
*
mul1_w
,
Node
*
mul2_w
,
Node
*
mul1
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
mul2
,
Node
*
eltadd_qk_b
,
Node
*
reshape2
,
Node
*
mul0_out
,
Node
*
reshape2_qkv_out
,
Node
*
matmul_qk
)
{
Node
*
mul1_out
,
Node
*
mul2_out
,
Node
*
mul0_w
,
Node
*
mul1_w
,
Node
*
mul2_w
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
eltadd_qk_b
,
Node
*
reshape2
,
Node
*
reshape2_qkv_out
,
Node
*
matmul_qk
)
{
auto
scale_attr
=
BOOST_GET_CONST
(
float
,
matmul_qk
->
Op
()
->
GetAttr
(
"alpha"
));
auto
scale_attr
=
BOOST_GET_CONST
(
float
,
matmul_qk
->
Op
()
->
GetAttr
(
"alpha"
));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
...
@@ -1314,7 +1394,8 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
...
@@ -1314,7 +1394,8 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
wq_tensor
->
Resize
(
combined_w_dims
);
wq_tensor
->
Resize
(
combined_w_dims
);
auto
*
new_combined_w_data
=
auto
*
new_combined_w_data
=
wq_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
wq_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
memcpy
(
new_combined_w_data
,
tmp_combined_w_data
,
memcpy
(
new_combined_w_data
,
tmp_combined_w_data
,
sizeof
(
float
)
*
wq_tensor
->
numel
());
sizeof
(
float
)
*
wq_tensor
->
numel
());
scope
->
EraseVars
({
mul1_w
->
Name
(),
mul2_w
->
Name
()});
scope
->
EraseVars
({
mul1_w
->
Name
(),
mul2_w
->
Name
()});
...
@@ -1326,15 +1407,17 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
    size_t bias_size = bq_tensor->numel();
    memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
-   memcpy(tmp_combined_bias_data + bias_size, bk_data,
-          sizeof(float) * bias_size);
+   memcpy(
+       tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
    memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
           sizeof(float) * bias_size);

    bq_tensor->Resize(combined_bias_dims);
    auto* new_combined_bias_data =
        bq_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_combined_bias_data, tmp_combined_bias_data,
           sizeof(float) * bq_tensor->numel());

    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
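The bias handling above is a plain concatenation: the Q bias is copied at offset 0, the K bias at bias_size, and the V bias at 2 * bias_size, after which the combined buffer is written back into bq_tensor. A self-contained sketch of the same packing (vector-based, names illustrative):

#include <cstring>
#include <vector>

// Concatenate the Q, K and V biases (each of length bias_size) into one
// buffer of length 3 * bias_size, mirroring the three memcpy calls above.
std::vector<float> CombineQkvBias(const std::vector<float>& bq,
                                  const std::vector<float>& bk,
                                  const std::vector<float>& bv) {
  const size_t bias_size = bq.size();
  std::vector<float> combined(3 * bias_size);
  std::memcpy(combined.data(), bq.data(), sizeof(float) * bias_size);
  std::memcpy(combined.data() + bias_size, bk.data(),
              sizeof(float) * bias_size);
  std::memcpy(combined.data() + 2 * bias_size, bv.data(),
              sizeof(float) * bias_size);
  return combined;
}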
...
@@ -1375,31 +1458,31 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
    GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_0_out, reshape2_0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_0_out, transpose2_0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_1_out, reshape2_1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_1_out, transpose2_1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_2_out, reshape2_2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_2_out, transpose2_2_out, multihead_pattern);

    // nodes need be removed
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
...
@@ -1422,20 +1505,20 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        softmax_qk_out, softmax_qk_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        matmul_qkv_out, matmul_qkv_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_qkv, transpose2_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);

    // If weights or biases in qkv's fc are shared by multiple multihead_matmul
    // patterns, we do not support this kind of fusion, this pass will not take
...
@@ -1447,9 +1530,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
    if (is_fc_params_shared) {
      return;
    }
-   fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
-                mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
-                reshape2_0, reshape2_qkv_out, matmul_qk);
+   fuse_creater(input0,
+                mul0,
+                mul1,
+                mul2,
+                mul0_out,
+                mul1_out,
+                mul2_out,
+                mul0_w,
+                mul1_w,
+                mul2_w,
+                eltadd0_b,
+                eltadd1_b,
+                eltadd2_b,
+                eltadd_qk_b,
+                reshape2_0,
+                reshape2_qkv_out,
+                matmul_qk);

    std::unordered_set<const Node*> marked_nodes({eltadd0,
                                                  eltadd1,
...
@@ -1510,19 +1607,28 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
  int fusion_count = BuildFusionV3(graph, name_scope_, scope);
  if (fusion_count > 0) {
    bool use_varseqlen = Get<bool>("use_varseqlen");
+   bool with_interleaved = Get<bool>("with_interleaved");
    std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

    if (use_varseqlen && pos_id != "" && mask_id != "") {
-     if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) {
-       VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v3";
+     if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
+         graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) {
+       if (with_interleaved) {
+         VLOG(3) << "start interleaved_format "
+                    "varseqlen_trt_multihead_matmul_fuse_pass_v3";
+       } else {
+         VLOG(3) << "start varseqlen_trt_multihead_matmul_fuse_pass_v3";
+       }
      } else {
        PADDLE_THROW(platform::errors::Fatal(
            "Use transformer'varseqlen need "
-           "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
+           "embedding_eltwise_layernorm_fuse_pass or "
+           "preln_embedding_eltwise_layernorm_fuse_"
+           "pass. please use no_varseqlen"));
      }
    } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
-     VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v3";
+     VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Use transformer'varseqlen need config: "
...
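The branches above reduce to one rule: the varseqlen path is legal only when both a position-id and a mask-id input are configured and one of the embedding+layernorm fuse passes (the ordinary one or the pre-layernorm variant) has already run; a fully empty configuration falls back to the plain path, and anything in between is a fatal misconfiguration. A standalone restatement of that rule as a sketch (the helper name and enum are illustrative, not part of the pass):

#include <string>

// Mirrors the gating in ApplyImpl above: varseqlen needs pos_id and mask_id
// to be configured and one of the embedding_eltwise_layernorm fuse passes
// (ordinary or preln) to have run first; otherwise the plain path is used.
enum class TransformerPath { kVarseqlen, kNoVarseqlen, kInvalidConfig };

TransformerPath ChooseTransformerPath(bool use_varseqlen,
                                      const std::string& pos_id,
                                      const std::string& mask_id,
                                      bool emb_layernorm_fused,
                                      bool preln_emb_layernorm_fused) {
  if (use_varseqlen && !pos_id.empty() && !mask_id.empty()) {
    return (emb_layernorm_fused || preln_emb_layernorm_fused)
               ? TransformerPath::kVarseqlen
               : TransformerPath::kInvalidConfig;
  }
  if (!use_varseqlen && pos_id.empty() && mask_id.empty()) {
    return TransformerPath::kNoVarseqlen;
  }
  return TransformerPath::kInvalidConfig;
}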
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
View file @ 178b2440
...
@@ -139,12 +139,12 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_scale, layer_norm_scale, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_variance, layer_norm_variance, fused_pattern);

    std::unordered_set<const Node *> del_node_set;
...
@@ -197,13 +197,15 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

    if (use_varseqlen && pos_id != "" && mask_id != "") {
-     if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+     if ((graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
+          graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
          graph->Has(framework::ir::kMultiheadMatmulPass)) {
        VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass";
      } else {
        PADDLE_THROW(platform::errors::Fatal(
            "Use transformer'varseqlen need "
-           "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
+           "trt_embedding_eltwise_layernorm_fuse_pass, "
+           "trt_multihead_matmul_fuse_pass. please use no_varseqlen"));
      }
    } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
      VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass";
...
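Compared with the multihead pass, the check here additionally requires framework::ir::kMultiheadMatmulPass, i.e. the fuse passes have an implied order under varseqlen. A sketch of that order as an illustrative pass list; the names come from the error messages above, and the real pass list used by the TensorRT subgraph engine contains more entries:

#include <string>
#include <vector>

// Illustrative ordering only, inferred from the graph->Has(...) checks above.
// Exactly one of the two embedding fuse passes is expected, depending on
// whether the model is post-layernorm or pre-layernorm (preln).
const std::vector<std::string> kVarseqlenPassOrder{
    "trt_embedding_eltwise_layernorm_fuse_pass",
    // "preln_embedding_eltwise_layernorm_fuse_pass",  // preln models instead
    "trt_multihead_matmul_fuse_pass",
    "trt_skip_layernorm_fuse_pass",
};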
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc
View file @ 178b2440
...
@@ -28,13 +28,21 @@ namespace tensorrt {
class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(7000)
    VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
-   if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
+   auto pos_id_name = engine_->tensorrt_transformer_posid();
+   auto mask_id_name = engine_->tensorrt_transformer_maskid();
+   bool flag_prelayernorm = engine_->with_interleaved() &&
+                            engine_->use_varseqlen() && pos_id_name != "" &&
+                            mask_id_name != "";
+
+   if (!flag_prelayernorm) {
      PADDLE_THROW(platform::errors::Fatal(
-         "PrelnErnie: If you want to use oss, must be with interleaved"));
+         "PrelnErnie: If you want to use varseqlen, must be with interleaved, "
+         "set pos_id_name, set mask_id_name."));
    }
    framework::OpDesc op_desc(op, nullptr);
    bool enable_int8 = op_desc.HasAttr("enable_int8");
...
@@ -43,7 +51,6 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
          platform::errors::Fatal("use with_interleaved must be int8."));
    }
    auto word_id_name = op_desc.Input("WordId").front();
-   auto pos_id_name = op_desc.Input("PosId").front();
    engine_->Set("ernie_pos_name", new std::string(pos_id_name));
    auto sent_id_name = op_desc.Input("SentId").front();
...
@@ -51,6 +58,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
    auto pos_emb_name = op_desc.Input("PosEmbedding").front();
    auto sent_emb_name = op_desc.Input("SentEmbedding").front();

+   engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
+   engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
+   engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
+
    std::vector<std::string> emb_names;
    emb_names =
        std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
...
@@ -81,7 +92,8 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
      input_embs.push_back(emb_data);
      emb_sizes.push_back(emb_size);
      PADDLE_ENFORCE_EQ(
          emb_dims.size(), 2,
          platform::errors::InvalidArgument(
              "The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims."));
    }
...
@@ -97,23 +109,31 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
    int output_int8 = 1;
    PADDLE_ENFORCE_EQ(
        input_num, 3,
        platform::errors::InvalidArgument(
            "When using oss and var-len, embedding_eltwise_layernorm op"
            "should have 3 inputs only, but got %d.",
            input_num));
    const std::vector<nvinfer1::PluginField> fields{
        {"bert_embeddings_layernorm_beta",
         bias,
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(bias_size)},
        {"bert_embeddings_layernorm_gamma",
         scale,
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(scale_size)},
        {"bert_embeddings_word_embeddings",
         input_embs[0],
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(emb_sizes[0])},
        {"bert_embeddings_token_type_embeddings",
         input_embs[2],
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(emb_sizes[2])},
        {"bert_embeddings_position_embeddings",
         input_embs[1],
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(emb_sizes[1])},
        {"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1},
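Field lists like this one are ultimately wrapped in an nvinfer1::PluginFieldCollection and handed to the TensorRT plugin registry so the CustomEmbLayerNormPluginDynamic plugin can be instantiated (its field names are the bert_embeddings_* entries above). A hedged sketch of that hand-off using the standard TensorRT API; the helper function and the creator version string are assumptions, not copied from this converter:

#include <NvInfer.h>

#include <vector>

// Sketch only: turn a PluginField list like the one above into a plugin
// instance. The version string "2" is an assumption and may not be the one
// this converter actually requests.
nvinfer1::IPluginV2* BuildEmbLayerNormPlugin(
    const std::vector<nvinfer1::PluginField>& fields) {
  nvinfer1::PluginFieldCollection fc;
  fc.nbFields = static_cast<int32_t>(fields.size());
  fc.fields = fields.data();
  auto* creator = getPluginRegistry()->getPluginCreator(
      "CustomEmbLayerNormPluginDynamic", "2");
  if (creator == nullptr) return nullptr;  // plugin library not registered
  return creator->createPlugin("CustomEmbLayerNormPluginDynamic", &fc);
}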
...
@@ -136,8 +156,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
    plugin_inputs.emplace_back(
        engine_->GetITensor(pos_id_name));  // cu_seqlens,
                                            // eval_placeholder_2
-   auto max_seqlen_tensor =
-       engine_->GetITensor(engine_->network()->getInput(3)->getName());
+   auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
    auto* shuffle_layer =
        TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
    nvinfer1::Dims shape_dim;
...
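Right after this hunk (elided here), shape_dim is filled in and the Shuffle layer flattens the mask tensor so its length can serve as the plugin's max_seqlen input. A hedged sketch of that continuation, reusing the hunk's own variables; the concrete reshape dimensions are an assumption based on the usual var-len BERT plugin wiring, not copied from the converter:

    // Sketch: collapse mask_id into a 1-D tensor whose extent is the padded
    // maximum sequence length, then feed it to the plugin as max_seqlen.
    shape_dim.nbDims = 1;
    shape_dim.d[0] = -1;  // infer the extent from the input tensor
    shuffle_layer->setReshapeDimensions(shape_dim);
    plugin_inputs.emplace_back(
        shuffle_layer->getOutput(0));  // max_seqlen, eval_placeholder_3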