PaddlePaddle / Paddle

Commit 178b2440 (unverified)
Authored Jun 23, 2022 by Wangzheee; committed via GitHub on Jun 23, 2022
general_prelayernorm_transformer (#43748)
Parent: dbf92d49
Showing 8 changed files with 511 additions and 199 deletions (+511 -199):

- paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc  (+43 -28)
- paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h  (+3 -1)
- paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc  (+18 -12)
- paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc  (+158 -19)
- paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h  (+23 -0)
- paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc  (+224 -118)
- paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc  (+8 -6)
- paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc  (+34 -15)
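All of the fuse and padding passes touched by this commit follow the same GraphPatternDetector idiom that recurs in the diffs below: build a pattern, register a handler that binds the matched IR nodes, then run the detector over the graph. The condensed sketch below is pieced together from the code in this commit for orientation only; the wrapper function name RunPrelnSkipLayernormPattern is hypothetical, everything else (types, macros, calls) appears verbatim in the diffs.

// Sketch of the detector idiom used by the passes in this commit (not a complete pass).
void RunPrelnSkipLayernormPattern(ir::Graph* graph, int* found_subgraph_count) {
  GraphPatternDetector gpd;
  patterns::PrelnSkipLayernorm preln_skip_layernorm(
      gpd.mutable_pattern(), "remove_padding_recover_padding_pass");
  preln_skip_layernorm();  // builds the pattern's nodes and links

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // Bind matched IR nodes to local variables by pattern-node name.
    GET_IR_NODE_FROM_SUBGRAPH(
        preln_skip_layernorm_op, preln_skip_layernorm_op, preln_skip_layernorm);
    // ... inspect or rewrite the matched subgraph here ...
    (*found_subgraph_count)++;
  };
  gpd(graph, handler);  // walk the graph and invoke the handler on each match
}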
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc (+43 -28)

@@ -31,7 +31,8 @@ namespace framework {
 namespace ir {
 namespace patterns {
 static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
                                const std::string& arg,
                                bool is_persist = false) {
   std::unordered_set<std::string> embedding_ops{"lookup_table",

@@ -41,7 +42,8 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
   if (is_persist) return node->assert_is_persistable_var();
   return node;
 }
 static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
                                    const std::string& arg) {
   std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                 "lookup_table_v2"};

@@ -62,6 +64,9 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
       create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
   std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                 "lookup_table_v2"};
+  auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
+  auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed");
   auto* lookup_table1 =
       pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table2 =

@@ -74,8 +79,10 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
       pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
   auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
                               ->assert_is_op_output("elementwise_add");
+  feed1->LinksTo({lookup_table1_x});
   lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
       .LinksTo({lookup_table1_out});
+  feed2->LinksTo({lookup_table2_x});
   lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
       .LinksTo({lookup_table2_out});
   eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})

@@ -88,6 +95,8 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
       create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
   std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                 "lookup_table_v2"};
+  auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
   auto* lookup_table1 =
       pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table1_out =

@@ -101,6 +110,7 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
           ->assert_is_op_output("elementwise_add");
   lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
       .LinksTo({lookup_table1_out});
+  feed1->LinksTo({lookup_table1_x});
   eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
       .LinksTo({eltwise_add_out});
 }

@@ -161,10 +171,10 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
     if (!IsCompat(subgraph, graph)) {

@@ -179,8 +189,12 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
     start_pattern_out_node.push_back(eltwise_add_out);

     std::unordered_set<Node*> rm_nodes;
     rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out,
                      lookup_table2_out, eltwise_add, eltwise_add_out});
     start_pattern_remove_nodes.push_back(rm_nodes);
   };
   gpd(graph, handler);

@@ -200,8 +214,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);

@@ -236,19 +250,19 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
   auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, skip_layernorm_pattern);
     if (!IsCompat(subgraph, graph)) {
       LOG(WARNING) << "Pass(PrelnSkipLayerNorm) in op compat failed.";
       return;

@@ -313,7 +327,7 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
       embs.push_back(inner_pattern_ins[js[iter]].second->Name());
     }
-    OpDesc new_op_desc;
+    OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
     new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm");
     new_op_desc.SetInput("Ids", ids);
     new_op_desc.SetInput("Embs", embs);

@@ -433,16 +447,17 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
   bool use_varseqlen = Get<bool>("use_varseqlen");
   bool with_interleaved = Get<bool>("with_interleaved");
   bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
-  if (!(enable_int8 && use_varseqlen && with_interleaved &&
-        with_dynamic_shape)) {
-    VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
-               "enable_int8, "
+  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
+  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
+  if (!(enable_int8 && use_varseqlen && with_interleaved && pos_id != "" &&
+        mask_id != "" && with_dynamic_shape)) {
+    VLOG(3) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
+               "enable_int8, set pos_id, set mask_id, "
                "use_varseqlen, with_interleaved, with_dynamic_shape. Stop this "
                "pass, "
                "please reconfig.";
     return;
   }
   int fusion_count =
       PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
   if (fusion_count > 0) {
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h (+3 -1)

@@ -51,7 +51,8 @@ struct PrelnEmbedding2Eltwise1Pattern : public PatternBase {
       : PatternBase(pattern, name_scope, "Prelnembedding2_eltwise1") {}
   void operator()();
+  PATTERN_DECL_NODE(feed1);
+  PATTERN_DECL_NODE(feed2);
   PATTERN_DECL_NODE(lookup_table1_x);
   PATTERN_DECL_NODE(lookup_table2_x);
   PATTERN_DECL_NODE(lookup_table1_w);

@@ -81,6 +82,7 @@ struct PrelnEmbedding1Eltwise1Pattern : public PatternBase {
                                const std::string& name_scope)
       : PatternBase(pattern, name_scope, "Prelnembedding1_eltwise1") {}
   void operator()();
+  PATTERN_DECL_NODE(feed1);
   PATTERN_DECL_NODE(lookup_table1_x);
   PATTERN_DECL_NODE(lookup_table1_w);
   PATTERN_DECL_NODE(lookup_table1);
paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc (+18 -12)

@@ -112,15 +112,21 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
   bool use_varseqlen = Get<bool>("use_varseqlen");
   bool with_interleaved = Get<bool>("with_interleaved");
   bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
+  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
+  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
   if (!(enable_int8 && use_varseqlen && with_interleaved &&
-        with_dynamic_shape)) {
-    VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
-               "use_varseqlen, "
-               "with_interleaved, with_dynamic_shape. Stop this pass, please "
-               "reconfig. ";
+        graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
+        graph->Has(framework::ir::kMultiheadMatmulPass) && pos_id != "" &&
+        mask_id != "" && with_dynamic_shape)) {
+    VLOG(3) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
+               "with_interleaved"
+               "use_varseqlen, preln_embedding_eltwise_layernorm_fuse_pass, "
+               "trt_multihead_matmul_fuse_pass"
+               "set pos_id, set mask_id, with_dynamic_shape. Stop this pass, "
+               "please "
+               "reconfig.";
     return;
   }
   int found_subgraph_count = 0;
   GraphPatternDetector gpd;

@@ -155,17 +161,17 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, fused_pattern);

     std::unordered_set<const Node *> del_node_set;

     // Create an PrelnSkipLayerNorm op node
-    OpDesc new_desc;
+    OpDesc new_desc(elementwise->Op()->Block());
     new_desc.SetType("preln_skip_layernorm");

     // inputs

@@ -209,8 +215,8 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
     found_subgraph_count++;
   };

   gpd(graph, handler);
   AddStatis(found_subgraph_count);
 }
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc (+158 -19)

@@ -35,6 +35,25 @@ void EmbEltwiseLayernorm::operator()() {
   emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out});
 }

+void PrelnEmbEltwiseLayernorm::operator()() {
+  // Create nodes for fused_preln_embedding_eltwise_layernorm.
+  auto* preln_emb_elt_layernorm_op =
+      pattern->NewNode(preln_emb_elt_layernorm_op_repr())
+          ->assert_is_op("fused_preln_embedding_eltwise_layernorm");
+  auto* preln_emb_elt_layernorm_out_0 =
+      pattern->NewNode(preln_emb_elt_layernorm_out_0_repr())
+          ->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
+                                "Out_0");
+  auto* preln_emb_elt_layernorm_out_1 =
+      pattern->NewNode(preln_emb_elt_layernorm_out_1_repr())
+          ->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
+                                "Out_1");
+
+  // Add links for fused_preln_embedding_eltwise_layernorm op.
+  preln_emb_elt_layernorm_op->LinksTo(
+      {preln_emb_elt_layernorm_out_0, preln_emb_elt_layernorm_out_1});
+}
+
 void SkipLayernorm::operator()() {
   // Create nodes for skip_layernorm.
   auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())

@@ -51,6 +70,30 @@ void SkipLayernorm::operator()() {
       .LinksTo({skip_layernorm_out});
 }

+void PrelnSkipLayernorm::operator()() {
+  // Create nodes for preln_skip_layernorm.
+  auto* preln_skip_layernorm_x =
+      pattern->NewNode(preln_skip_layernorm_x_repr())
+          ->assert_is_op_input("preln_skip_layernorm", "X");
+  auto* preln_skip_layernorm_y =
+      pattern->NewNode(preln_skip_layernorm_y_repr())
+          ->assert_is_op_input("preln_skip_layernorm", "Y");
+  auto* preln_skip_layernorm_op =
+      pattern->NewNode(preln_skip_layernorm_op_repr())
+          ->assert_is_op("preln_skip_layernorm");
+  auto* preln_skip_layernorm_out_0 =
+      pattern->NewNode(preln_skip_layernorm_out_0_repr())
+          ->assert_is_op_output("preln_skip_layernorm", "Out_0");
+  auto* preln_skip_layernorm_out_1 =
+      pattern->NewNode(preln_skip_layernorm_out_1_repr())
+          ->assert_is_op_output("preln_skip_layernorm", "Out_1");
+
+  // Add links for preln_skip_layernorm op.
+  preln_skip_layernorm_op
+      ->LinksFrom({preln_skip_layernorm_x, preln_skip_layernorm_y})
+      .LinksTo({preln_skip_layernorm_out_0, preln_skip_layernorm_out_1});
+}
+
 void MultiheadMatmul::operator()() {
   // Create nodes for multihead_matmul.
   auto* multihead_matmul_input =

@@ -96,10 +139,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
   std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

   if (use_varseqlen && pos_id != "" && mask_id != "" &&
-      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+      (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
+       graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
       graph->Has(framework::ir::kMultiheadMatmulPass)) {
     VLOG(3) << "start varseqlen remove_padding_recover_padding_pass";
   } else {
     VLOG(3) << "remove_padding_recover_padding_pass check failed";
     return;
   }

@@ -131,9 +176,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     remove_padding.SetOutput("Out", {remove_padding_out_name});

     // set out_threshold for int8
-    if (op_node->Op()->HasAttr("out_threshold")) {
+    if (op_node->Op()->HasAttr("Input_scale")) {
       remove_padding.SetAttr("out_threshold",
-                             op_node->Op()->GetAttr("out_threshold"));
+                             op_node->Op()->GetAttr("Input_scale"));
+    } else {
+      VLOG(3) << "remove_padding_op has not out_threshold, because next op has "
+                 "not Input_scale.";
     }

     auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);

@@ -194,6 +242,15 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     if (op_node->Op()->HasAttr("out_threshold")) {
       recover_padding.SetAttr("out_threshold",
                               op_node->Op()->GetAttr("out_threshold"));
+    } else if (op_node->Op()->HasAttr("out_0_threshold")) {
+      recover_padding.SetAttr("out_threshold",
+                              op_node->Op()->GetAttr("out_0_threshold"));
+    } else if (op_node->Op()->HasAttr("out_1_threshold")) {
+      recover_padding.SetAttr("out_threshold",
+                              op_node->Op()->GetAttr("out_1_threshold"));
+    } else {
+      VLOG(3) << "recover_padding_op has not out_threshold, because previous "
+                 "op has not out_*_threshold.";
     }

     auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);

@@ -241,9 +298,11 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
               "fused_embedding_eltwise_layernorm";

     GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op, emb_elt_layernorm_op,
                               fused_embedding_eltwise_layernorm);
     GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out, emb_elt_layernorm_out,
                               fused_embedding_eltwise_layernorm);

     insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out);

@@ -263,12 +322,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
               "multihead_matmul";

     GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, multihead_matmul);
     GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_op, multihead_matmul_op, multihead_matmul);
     GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, multihead_matmul);

     multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape();

@@ -289,14 +348,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
               "skip_layernorm";

     GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x, skip_layernorm);
     GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y, skip_layernorm);
     GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op, skip_layernorm);
     GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out, skip_layernorm);

     std::vector<int64_t> skip_layernorm_x_shape =
         skip_layernorm_x->Var()->GetShape();

@@ -417,6 +476,86 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
   };
   gpd4(graph, handler4);

+  GraphPatternDetector gpd5;
+  patterns::PrelnEmbEltwiseLayernorm fused_preln_embedding_eltwise_layernorm(
+      gpd5.mutable_pattern(), "remove_padding_recover_padding_pass");
+  fused_preln_embedding_eltwise_layernorm();
+
+  auto handler5 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
+               "fused_preln_embedding_eltwise_layernorm";
+
+    GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_op,
+                              preln_emb_elt_layernorm_op,
+                              fused_preln_embedding_eltwise_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_0,
+                              preln_emb_elt_layernorm_out_0,
+                              fused_preln_embedding_eltwise_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_1,
+                              preln_emb_elt_layernorm_out_1,
+                              fused_preln_embedding_eltwise_layernorm);
+
+    insert_recover_padding_op(preln_emb_elt_layernorm_op,
+                              preln_emb_elt_layernorm_out_0);
+    insert_recover_padding_op(preln_emb_elt_layernorm_op,
+                              preln_emb_elt_layernorm_out_1);
+    found_subgraph_count++;
+  };
+  gpd5(graph, handler5);
+
+  GraphPatternDetector gpd6;
+  patterns::PrelnSkipLayernorm preln_skip_layernorm(
+      gpd6.mutable_pattern(), "remove_padding_recover_padding_pass");
+  preln_skip_layernorm();
+
+  auto handler6 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
+               "preln_skip_layernorm";
+
+    GET_IR_NODE_FROM_SUBGRAPH(
+        preln_skip_layernorm_x, preln_skip_layernorm_x, preln_skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        preln_skip_layernorm_y, preln_skip_layernorm_y, preln_skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        preln_skip_layernorm_op, preln_skip_layernorm_op, preln_skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_0,
+                              preln_skip_layernorm_out_0,
+                              preln_skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_1,
+                              preln_skip_layernorm_out_1,
+                              preln_skip_layernorm);
+
+    std::vector<int64_t> skip_layernorm_x_shape =
+        preln_skip_layernorm_x->Var()->GetShape();
+    if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
+      check_flag = false;
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+    for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) {
+      if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) {
+        check_flag = false;
+      }
+    }
+    if (!check_flag) {
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+    insert_remove_padding_op(preln_skip_layernorm_x, preln_skip_layernorm_op);
+    insert_remove_padding_op(preln_skip_layernorm_y, preln_skip_layernorm_op);
+    insert_recover_padding_op(preln_skip_layernorm_op,
+                              preln_skip_layernorm_out_0);
+    insert_recover_padding_op(preln_skip_layernorm_op,
+                              preln_skip_layernorm_out_1);
+    found_subgraph_count++;
+  };
+  gpd6(graph, handler6);
+
   AddStatis(found_subgraph_count);
 }
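For context on what the pass above orchestrates: with var-len (varseqlen) batches, remove_padding strips padding tokens before the fused transformer ops and recover_padding re-inserts them afterwards. The standalone sketch below only illustrates that idea on a plain CPU buffer; the function names and signatures are hypothetical and this is not the Paddle/TensorRT implementation, which operates on engine tensors and sequence-offset inputs.

#include <algorithm>
#include <vector>

// Conceptual illustration of remove_padding / recover_padding on a
// [batch, max_seq_len, hidden] buffer. Hypothetical helpers, not Paddle code.
std::vector<float> RemovePadding(const std::vector<float>& padded,
                                 const std::vector<int>& seq_lens,
                                 int max_seq_len, int hidden) {
  std::vector<float> packed;
  for (size_t b = 0; b < seq_lens.size(); ++b) {
    const float* src = padded.data() + b * max_seq_len * hidden;
    // Keep only the first seq_lens[b] tokens of this sequence.
    packed.insert(packed.end(), src, src + seq_lens[b] * hidden);
  }
  return packed;
}

std::vector<float> RecoverPadding(const std::vector<float>& packed,
                                  const std::vector<int>& seq_lens,
                                  int max_seq_len, int hidden) {
  std::vector<float> padded(seq_lens.size() * max_seq_len * hidden, 0.0f);
  const float* src = packed.data();
  for (size_t b = 0; b < seq_lens.size(); ++b) {
    float* dst = padded.data() + b * max_seq_len * hidden;
    std::copy(src, src + seq_lens[b] * hidden, dst);  // padded tail stays zero
    src += seq_lens[b] * hidden;
  }
  return padded;
}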
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h (+23 -0)

@@ -41,6 +41,16 @@ struct EmbEltwiseLayernorm : public PatternBase {
   PATTERN_DECL_NODE(emb_elt_layernorm_out);
 };

+struct PrelnEmbEltwiseLayernorm : public PatternBase {
+  PrelnEmbEltwiseLayernorm(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "preln_emb_elt_layernorm") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(preln_emb_elt_layernorm_op);
+  PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_0);
+  PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_1);
+};
+
 struct SkipLayernorm : public PatternBase {
   SkipLayernorm(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "skip_layernorm") {}

@@ -53,6 +63,19 @@ struct SkipLayernorm : public PatternBase {
   PATTERN_DECL_NODE(skip_layernorm_out);
 };

+struct PrelnSkipLayernorm : public PatternBase {
+  PrelnSkipLayernorm(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "preln_skip_layernorm") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(preln_skip_layernorm_x);
+  PATTERN_DECL_NODE(preln_skip_layernorm_y);
+  PATTERN_DECL_NODE(preln_skip_layernorm_op);
+  PATTERN_DECL_NODE(preln_skip_layernorm_out_0);
+  PATTERN_DECL_NODE(preln_skip_layernorm_out_1);
+};
+
 struct MultiheadMatmul : public PatternBase {
   MultiheadMatmul(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "multihead_matmul") {}
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc (+224 -118)

@@ -51,11 +51,20 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   multihead_pattern();
   // Create New OpDesc
   auto fuse_creater = [&](Node* input0,
                           Node* mul0,
                           Node* mul1,
                           Node* mul2,
                           Node* mul0_out,
                           Node* mul1_out,
                           Node* mul2_out,
                           Node* eltadd0_b,
                           Node* eltadd1_b,
                           Node* eltadd2_b,
                           Node* eltadd_qk_b,
                           Node* reshape2,
                           Node* reshape2_qkv_out,
                           Node* scale,
                           Node* scale_out) {
     auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
     // auto scale_bias = BOOST_GET_CONST(float, scale->Op()->GetAttr("bias"));

@@ -123,11 +132,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
     GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);

@@ -135,21 +144,21 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
     GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, multihead_pattern);

     // nodes need be removed
     GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);

@@ -172,24 +181,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
     GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);

     fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
                  eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
                  reshape2_qkv_out, scale, scale_out);

     std::unordered_set<const Node*> marked_nodes({eltadd0,

@@ -777,14 +798,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
   multihead_pattern();
   // Create New OpDesc
   auto fuse_creater = [&](Node* input0,
                           Node* mul0,
                           Node* mul1,
                           Node* mul2,
                           Node* mul0_out,
                           Node* mul1_out,
                           Node* mul2_out,
                           Node* mul0_w,
                           Node* mul1_w,
                           Node* mul2_w,
                           Node* eltadd0_b,
                           Node* eltadd1_b,
                           Node* eltadd2_b,
                           Node* eltadd_qk_b,
                           Node* reshape2,
                           Node* reshape2_qkv_out,
                           Node* scale,
                           Node* scale_out,
                           Node* softmax_qk,
                           Node* eltadd0,
                           Node* eltadd1,
                           Node* eltadd2,
                           Node* matmul_qk,
                           Node* reshape2_qkv) {
     auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
     // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)

@@ -842,7 +879,8 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     wq_tensor->Resize(combined_w_dims);
     auto* new_combined_w_data =
         wq_tensor->mutable_data<float>(platform::CPUPlace());
     memcpy(new_combined_w_data, tmp_combined_w_data,
            sizeof(float) * wq_tensor->numel());

     scope->EraseVars({mul1_w->Name(), mul2_w->Name()});

@@ -854,15 +892,17 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     size_t bias_size = bq_tensor->numel();
     memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
     memcpy(tmp_combined_bias_data + bias_size, bk_data,
            sizeof(float) * bias_size);
     memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
            sizeof(float) * bias_size);

     bq_tensor->Resize(combined_bias_dims);
     auto* new_combined_bias_data =
         bq_tensor->mutable_data<float>(platform::CPUPlace());
     memcpy(new_combined_bias_data, tmp_combined_bias_data,
            sizeof(float) * bq_tensor->numel());

     scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

@@ -944,11 +984,11 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);

@@ -956,21 +996,21 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, multihead_pattern);

     // nodes need be removed
     GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);

@@ -993,20 +1033,20 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);

     // If weights or biases in qkv's fc are shared by multiple multihead_matmul
     // patterns, we do not support this kind of fusion, this pass will not take

@@ -1018,10 +1058,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     if (is_fc_params_shared) {
       return;
     }
     fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
                  mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
                  eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out,
                  softmax_qk, eltadd0, eltadd1, eltadd2, matmul_qk,
                  reshape2_qkv);

     std::unordered_set<const Node*> marked_nodes({eltadd0,
                                                   eltadd1,

@@ -1083,19 +1143,28 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
   int fusion_count = BuildFusionV2(graph, name_scope_, scope);
   if (fusion_count > 0) {
     bool use_varseqlen = Get<bool>("use_varseqlen");
     bool with_interleaved = Get<bool>("with_interleaved");
     std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
     std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
     if (use_varseqlen && pos_id != "" && mask_id != "") {
-      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) {
-        VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v2";
+      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
+          graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) {
+        if (with_interleaved) {
+          VLOG(3) << "start interleaved_format "
+                     "varseqlen_trt_multihead_matmul_fuse_pass_v2";
+        } else {
+          VLOG(3) << "start varseqlen_trt_multihead_matmul_fuse_pass_v2";
+        }
       } else {
-        PADDLE_THROW(platform::errors::Fatal(
-            "Use transformer'varseqlen need "
-            "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
+        PADDLE_THROW(
+            platform::errors::Fatal("Use transformer'varseqlen need "
+                                    "embedding_eltwise_layernorm_fuse_pass or "
+                                    "preln_embedding_eltwise_layernorm_fuse_"
+                                    "pass. please use no_varseqlen"));
       }
     } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
-      VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v2";
+      VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
     } else {
       PADDLE_THROW(
           platform::errors::Fatal("Use transformer'varseqlen need config: "

@@ -1251,12 +1320,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
   multihead_pattern();
   // Create New OpDesc
   auto fuse_creater = [&](Node* input0,
                           Node* mul0,
                           Node* mul1,
                           Node* mul2,
                           Node* mul0_out,
                           Node* mul1_out,
                           Node* mul2_out,
                           Node* mul0_w,
                           Node* mul1_w,
                           Node* mul2_w,
                           Node* eltadd0_b,
                           Node* eltadd1_b,
                           Node* eltadd2_b,
                           Node* eltadd_qk_b,
                           Node* reshape2,
                           Node* reshape2_qkv_out,
                           Node* matmul_qk) {
     auto scale_attr = BOOST_GET_CONST(float, matmul_qk->Op()->GetAttr("alpha"));
     // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)

@@ -1314,7 +1394,8 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
     wq_tensor->Resize(combined_w_dims);
     auto* new_combined_w_data =
         wq_tensor->mutable_data<float>(platform::CPUPlace());
     memcpy(new_combined_w_data, tmp_combined_w_data,
            sizeof(float) * wq_tensor->numel());

     scope->EraseVars({mul1_w->Name(), mul2_w->Name()});

@@ -1326,15 +1407,17 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
     size_t bias_size = bq_tensor->numel();
     memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
     memcpy(tmp_combined_bias_data + bias_size, bk_data,
            sizeof(float) * bias_size);
     memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
            sizeof(float) * bias_size);

     bq_tensor->Resize(combined_bias_dims);
     auto* new_combined_bias_data =
         bq_tensor->mutable_data<float>(platform::CPUPlace());
     memcpy(new_combined_bias_data, tmp_combined_bias_data,
            sizeof(float) * bq_tensor->numel());

     scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

@@ -1375,31 +1458,31 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, multihead_pattern);

     // nodes need be removed
     GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);

@@ -1422,20 +1505,20 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);

     // If weights or biases in qkv's fc are shared by multiple multihead_matmul
     // patterns, we do not support this kind of fusion, this pass will not take

@@ -1447,9 +1530,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
     if (is_fc_params_shared) {
       return;
     }
     fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
                  mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
                  eltadd_qk_b, reshape2_0, reshape2_qkv_out, matmul_qk);

     std::unordered_set<const Node*> marked_nodes({eltadd0,
                                                   eltadd1,

@@ -1510,19 +1607,28 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
   int fusion_count = BuildFusionV3(graph, name_scope_, scope);
   if (fusion_count > 0) {
     bool use_varseqlen = Get<bool>("use_varseqlen");
     bool with_interleaved = Get<bool>("with_interleaved");
     std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
     std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
     if (use_varseqlen && pos_id != "" && mask_id != "") {
-      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) {
-        VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v3";
+      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
+          graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) {
+        if (with_interleaved) {
+          VLOG(3) << "start interleaved_format "
+                     "varseqlen_trt_multihead_matmul_fuse_pass_v3";
+        } else {
+          VLOG(3) << "start varseqlen_trt_multihead_matmul_fuse_pass_v3";
+        }
       } else {
-        PADDLE_THROW(platform::errors::Fatal(
-            "Use transformer'varseqlen need "
-            "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
+        PADDLE_THROW(
+            platform::errors::Fatal("Use transformer'varseqlen need "
+                                    "embedding_eltwise_layernorm_fuse_pass or "
+                                    "preln_embedding_eltwise_layernorm_fuse_"
+                                    "pass. please use no_varseqlen"));
       }
     } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
-      VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v3";
+      VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
     } else {
       PADDLE_THROW(
           platform::errors::Fatal("Use transformer'varseqlen need config: "
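The BuildFusionV2/BuildFusionV3 hunks above pack the separate Q, K and V bias tensors into one combined buffer with three memcpy calls before creating the fused multihead op; the layout is simply the three biases concatenated. A minimal standalone equivalent on plain arrays (hypothetical helper name and signature, not Paddle tensors) is shown below.

#include <cstring>
#include <vector>

// Minimal equivalent of the Q/K/V bias packing done in the fuse pass:
// combined = [bq | bk | bv], each of length bias_size floats.
std::vector<float> CombineQkvBias(const float* bq_data, const float* bk_data,
                                  const float* bv_data, size_t bias_size) {
  std::vector<float> combined(3 * bias_size);
  std::memcpy(combined.data(), bq_data, sizeof(float) * bias_size);
  std::memcpy(combined.data() + bias_size, bk_data, sizeof(float) * bias_size);
  std::memcpy(combined.data() + 2 * bias_size, bv_data,
              sizeof(float) * bias_size);
  return combined;
}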
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc (+8 -6)

@@ -139,12 +139,12 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, fused_pattern);

     std::unordered_set<const Node *> del_node_set;

@@ -197,13 +197,15 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
   std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

   if (use_varseqlen && pos_id != "" && mask_id != "") {
-    if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+    if ((graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
+         graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
         graph->Has(framework::ir::kMultiheadMatmulPass)) {
       VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass";
     } else {
       PADDLE_THROW(platform::errors::Fatal(
           "Use transformer'varseqlen need "
-          "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
+          "trt_embedding_eltwise_layernorm_fuse_pass, "
+          "trt_multihead_matmul_fuse_pass. please use no_varseqlen"));
     }
   } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
     VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass";
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc (+34 -15)

@@ -28,13 +28,21 @@ namespace tensorrt {
 class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope,
                   bool test_mode) override {
 #if IS_TRT_VERSION_GE(7000)
     VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
-    if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
+    auto pos_id_name = engine_->tensorrt_transformer_posid();
+    auto mask_id_name = engine_->tensorrt_transformer_maskid();
+    bool flag_prelayernorm = engine_->with_interleaved() &&
+                             engine_->use_varseqlen() && pos_id_name != "" &&
+                             mask_id_name != "";
+
+    if (!flag_prelayernorm) {
       PADDLE_THROW(platform::errors::Fatal(
-          "PrelnErnie: If you want to use oss, must be with interleaved"));
+          "PrelnErnie: If you want to use varseqlen, must be with interleaved, "
+          "set pos_id_name, set mask_id_name."));
     }
     framework::OpDesc op_desc(op, nullptr);
     bool enable_int8 = op_desc.HasAttr("enable_int8");

@@ -43,7 +51,6 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
           platform::errors::Fatal("use with_interleaved must be int8."));
     }
     auto word_id_name = op_desc.Input("WordId").front();
-    auto pos_id_name = op_desc.Input("PosId").front();
     engine_->Set("ernie_pos_name", new std::string(pos_id_name));

     auto sent_id_name = op_desc.Input("SentId").front();

@@ -51,6 +58,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
     auto pos_emb_name = op_desc.Input("PosEmbedding").front();
     auto sent_emb_name = op_desc.Input("SentEmbedding").front();

+    engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
+    engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
+    engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
+
     std::vector<std::string> emb_names;
     emb_names =
         std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};

@@ -81,7 +92,8 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
       input_embs.push_back(emb_data);
       emb_sizes.push_back(emb_size);
       PADDLE_ENFORCE_EQ(emb_dims.size(), 2,
                         platform::errors::InvalidArgument(
                             "The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims."));
     }

@@ -97,23 +109,31 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
     int output_int8 = 1;
     PADDLE_ENFORCE_EQ(input_num, 3,
                       platform::errors::InvalidArgument(
                           "When using oss and var-len, embedding_eltwise_layernorm op"
                           "should have 3 inputs only, but got %d.",
                           input_num));
     const std::vector<nvinfer1::PluginField> fields{
         {"bert_embeddings_layernorm_beta", bias,
          nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(bias_size)},
         {"bert_embeddings_layernorm_gamma", scale,
          nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(scale_size)},
         {"bert_embeddings_word_embeddings", input_embs[0],
          nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(emb_sizes[0])},
         {"bert_embeddings_token_type_embeddings", input_embs[2],
          nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(emb_sizes[2])},
         {"bert_embeddings_position_embeddings", input_embs[1],
          nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(emb_sizes[1])},
         {"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1},

@@ -136,8 +156,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
     plugin_inputs.emplace_back(
         engine_->GetITensor(pos_id_name));  // cu_seqlens,
                                             // eval_placeholder_2
-    auto max_seqlen_tensor =
-        engine_->GetITensor(engine_->network()->getInput(3)->getName());
+    auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
     auto* shuffle_layer =
         TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
     nvinfer1::Dims shape_dim;