wmsofts / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit 178b2440
Authored June 23, 2022 by Wangzheee; committed by GitHub on June 23, 2022
general_prelayernorm_transformer (#43748)
Parent: dbf92d49
Showing 8 changed files with 511 additions and 199 deletions (+511 −199)
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc   +43   -28
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h    +3    -1
paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc                +18   -12
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc           +158  -19
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h            +23   -0
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc                +224  -118
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc                  +8    -6
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc     +34   -15
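The commit title refers to pre-layernorm ("preln") transformer blocks. For orientation only (a toy sketch, not Paddle code): the existing post-LN skip_layernorm fusion produces a single output, LayerNorm(x + y), while the pre-LN ops touched here (preln_skip_layernorm, fused_preln_embedding_eltwise_layernorm) expose two outputs, Out_0 and Out_1, because a pre-LN block needs both the raw residual sum and its normalized form. A minimal illustration with a naive LayerNorm and hypothetical function names:

#include <cmath>
#include <vector>

// Naive LayerNorm over one vector: (x - mean) / sqrt(var + eps) * gamma + beta.
std::vector<float> LayerNorm(const std::vector<float>& x,
                             const std::vector<float>& gamma,
                             const std::vector<float>& beta,
                             float eps = 1e-5f) {
  float mean = 0.f, var = 0.f;
  for (float v : x) mean += v;
  mean /= x.size();
  for (float v : x) var += (v - mean) * (v - mean);
  var /= x.size();
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    out[i] = (x[i] - mean) / std::sqrt(var + eps) * gamma[i] + beta[i];
  return out;
}

// Post-LN skip_layernorm: one output, LayerNorm(x + y).
std::vector<float> SkipLayerNormPostLN(const std::vector<float>& x,
                                       const std::vector<float>& y,
                                       const std::vector<float>& gamma,
                                       const std::vector<float>& beta) {
  std::vector<float> sum(x.size());
  for (size_t i = 0; i < x.size(); ++i) sum[i] = x[i] + y[i];
  return LayerNorm(sum, gamma, beta);
}

// Pre-LN preln_skip_layernorm: two outputs, Out_0 = x + y (kept for the next
// residual connection) and Out_1 = LayerNorm(Out_0) (fed to the next sub-layer).
void SkipLayerNormPreLN(const std::vector<float>& x,
                        const std::vector<float>& y,
                        const std::vector<float>& gamma,
                        const std::vector<float>& beta,
                        std::vector<float>* out_0,
                        std::vector<float>* out_1) {
  out_0->resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) (*out_0)[i] = x[i] + y[i];
  *out_1 = LayerNorm(*out_0, gamma, beta);
}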
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc
@@ -31,7 +31,8 @@ namespace framework {
namespace ir {
namespace patterns {
static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
                               const std::string& arg, bool is_persist = false) {
  std::unordered_set<std::string> embedding_ops{"lookup_table",

@@ -41,7 +42,8 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
  if (is_persist) return node->assert_is_persistable_var();
  return node;
}
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
                                   const std::string& arg) {
  std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                "lookup_table_v2"};

@@ -62,6 +64,9 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
      create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
  std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                "lookup_table_v2"};
+  auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
+  auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed");
  auto* lookup_table1 =
      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
  auto* lookup_table2 =

@@ -74,8 +79,10 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
                              ->assert_is_op_output("elementwise_add");
+  feed1->LinksTo({lookup_table1_x});
  lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
      .LinksTo({lookup_table1_out});
+  feed2->LinksTo({lookup_table2_x});
  lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
      .LinksTo({lookup_table2_out});
  eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})

@@ -88,6 +95,8 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
      create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
  std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                "lookup_table_v2"};
+  auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
  auto* lookup_table1 =
      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
  auto* lookup_table1_out =

@@ -101,6 +110,7 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
          ->assert_is_op_output("elementwise_add");
  lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
      .LinksTo({lookup_table1_out});
+  feed1->LinksTo({lookup_table1_x});
  eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
      .LinksTo({eltwise_add_out});
}

@@ -161,10 +171,10 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        lookup_table1_out, lookup_table1_out, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        lookup_table2_out, lookup_table2_out, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
    if (!IsCompat(subgraph, graph)) {

@@ -179,8 +189,12 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
    start_pattern_out_node.push_back(eltwise_add_out);

    std::unordered_set<Node*> rm_nodes;
    rm_nodes.insert({lookup_table1,
                     lookup_table2,
                     lookup_table1_out,
                     lookup_table2_out,
                     eltwise_add,
                     eltwise_add_out});
    start_pattern_remove_nodes.push_back(rm_nodes);
  };
  gpd(graph, handler);

@@ -200,8 +214,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        lookup_table1_out, lookup_table1_out, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);

@@ -236,19 +250,19 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
  auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        eltwise_add_out, eltwise_add_out, skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_out, layer_norm_out, skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_bias, layer_norm_bias, skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_scale, layer_norm_scale, skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_mean, layer_norm_mean, skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_variance, layer_norm_variance, skip_layernorm_pattern);
    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "Pass(PrelnSkipLayerNorm) in op compat failed.";
      return;

@@ -313,7 +327,7 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
      embs.push_back(inner_pattern_ins[js[iter]].second->Name());
    }

-    OpDesc new_op_desc;
+    OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
    new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm");
    new_op_desc.SetInput("Ids", ids);
    new_op_desc.SetInput("Embs", embs);

@@ -433,16 +447,17 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
  bool use_varseqlen = Get<bool>("use_varseqlen");
  bool with_interleaved = Get<bool>("with_interleaved");
  bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
-  if (!(enable_int8 && use_varseqlen && with_interleaved &&
-        with_dynamic_shape)) {
-    VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
-               "enable_int8, "
+  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
+  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
+  if (!(enable_int8 && use_varseqlen && with_interleaved && pos_id != "" &&
+        mask_id != "" && with_dynamic_shape)) {
+    VLOG(3) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
+               "enable_int8, set pos_id, set mask_id, "
               "use_varseqlen, with_interleaved, with_dynamic_shape. Stop this "
               "pass, "
               "please reconfig.";
    return;
  }
  int fusion_count =
      PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
  if (fusion_count > 0) {
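Most of the changes in this file follow one shape: declare pattern nodes, then register a handler lambda that the GraphPatternDetector invokes once per matched subgraph, where GET_IR_NODE_FROM_SUBGRAPH pulls out the matched nodes and the handler rewrites them into a fused op. A rough, self-contained sketch of that detect-then-rewrite flow, using toy types rather than Paddle's GraphPatternDetector (all names below are hypothetical):

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Toy "graph": a flat list of op names standing in for Paddle's ir::Graph.
using Graph = std::vector<std::string>;
// The handler receives the graph and the index where the matched pattern starts.
using Handler = std::function<void(Graph*, std::size_t)>;

// Toy detector: find every occurrence of a two-op sequence and invoke the
// handler, mirroring how gpd(graph, handler) walks matched subgraphs.
int DetectAndRewrite(Graph* g, const std::string& a, const std::string& b,
                     const Handler& handler) {
  int found = 0;
  for (std::size_t i = 0; i + 1 < g->size(); ++i) {
    if ((*g)[i] == a && (*g)[i + 1] == b) {
      handler(g, i);
      ++found;
    }
  }
  return found;
}

int main() {
  Graph g = {"lookup_table", "elementwise_add", "layer_norm"};
  int fusion_count = DetectAndRewrite(
      &g, "elementwise_add", "layer_norm", [](Graph* g, std::size_t i) {
        // Replace the matched pair with one fused op, as the fuse passes do.
        (*g)[i] = "preln_skip_layernorm";
        g->erase(g->begin() + i + 1);
      });
  std::cout << "fused " << fusion_count << " subgraph(s), ops left: "
            << g->size() << "\n";  // prints: fused 1 subgraph(s), ops left: 2
}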
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h
@@ -51,7 +51,8 @@ struct PrelnEmbedding2Eltwise1Pattern : public PatternBase {
      : PatternBase(pattern, name_scope, "Prelnembedding2_eltwise1") {}
  void operator()();
  PATTERN_DECL_NODE(feed1);
  PATTERN_DECL_NODE(feed2);
  PATTERN_DECL_NODE(lookup_table1_x);
  PATTERN_DECL_NODE(lookup_table2_x);
  PATTERN_DECL_NODE(lookup_table1_w);

@@ -81,6 +82,7 @@ struct PrelnEmbedding1Eltwise1Pattern : public PatternBase {
                                const std::string& name_scope)
      : PatternBase(pattern, name_scope, "Prelnembedding1_eltwise1") {}
  void operator()();
  PATTERN_DECL_NODE(feed1);
  PATTERN_DECL_NODE(lookup_table1_x);
  PATTERN_DECL_NODE(lookup_table1_w);
  PATTERN_DECL_NODE(lookup_table1);
paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc
@@ -112,15 +112,21 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
  bool use_varseqlen = Get<bool>("use_varseqlen");
  bool with_interleaved = Get<bool>("with_interleaved");
  bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
+  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
+  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
  if (!(enable_int8 && use_varseqlen && with_interleaved &&
-        with_dynamic_shape)) {
-    VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
-               "use_varseqlen, "
-               "with_interleaved, with_dynamic_shape. Stop this pass, please "
-               "reconfig. ";
+        graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
+        graph->Has(framework::ir::kMultiheadMatmulPass) && pos_id != "" &&
+        mask_id != "" && with_dynamic_shape)) {
+    VLOG(3) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
+               "with_interleaved"
+               "use_varseqlen, preln_embedding_eltwise_layernorm_fuse_pass, "
+               "trt_multihead_matmul_fuse_pass"
+               "set pos_id, set mask_id, with_dynamic_shape. Stop this pass, "
+               "please "
+               "reconfig.";
    return;
  }

  int found_subgraph_count = 0;
  GraphPatternDetector gpd;

@@ -155,17 +161,17 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_scale, layer_norm_scale, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_variance, layer_norm_variance, fused_pattern);

    std::unordered_set<const Node *> del_node_set;

    // Create an PrelnSkipLayerNorm op node
-    OpDesc new_desc;
+    OpDesc new_desc(elementwise->Op()->Block());
    new_desc.SetType("preln_skip_layernorm");

    // inputs

@@ -209,8 +215,8 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
    found_subgraph_count++;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}
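The rewritten guard above only lets the varseqlen rewrite proceed when the earlier passes have already marked the graph (kPrelnEmbEltwiseLayernormPass, kMultiheadMatmulPass) and when the pos_id and mask_id tensor names are configured. A toy illustration of this kind of pass gating, using a plain set of string markers instead of the Paddle Graph attribute API (function and marker names here are hypothetical):

#include <iostream>
#include <set>
#include <string>

// Toy stand-ins for graph markers such as kPrelnEmbEltwiseLayernormPass:
// earlier passes insert a key, later passes test for it before rewriting.
bool ShouldApplyVarseqlenSkipLayerNorm(
    const std::set<std::string>& graph_markers, bool enable_int8,
    bool use_varseqlen, bool with_interleaved, bool with_dynamic_shape,
    const std::string& pos_id, const std::string& mask_id) {
  return enable_int8 && use_varseqlen && with_interleaved &&
         graph_markers.count("preln_embedding_eltwise_layernorm_fuse_pass") &&
         graph_markers.count("trt_multihead_matmul_fuse_pass") &&
         !pos_id.empty() && !mask_id.empty() && with_dynamic_shape;
}

int main() {
  std::set<std::string> markers = {
      "preln_embedding_eltwise_layernorm_fuse_pass",
      "trt_multihead_matmul_fuse_pass"};
  std::cout << ShouldApplyVarseqlenSkipLayerNorm(markers, true, true, true,
                                                 true, "pos_id", "mask_id")
            << "\n";  // prints 1: all prerequisites satisfied
}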
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
@@ -35,6 +35,25 @@ void EmbEltwiseLayernorm::operator()() {
  emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out});
}

void PrelnEmbEltwiseLayernorm::operator()() {
  // Create nodes for fused_preln_embedding_eltwise_layernorm.
  auto* preln_emb_elt_layernorm_op =
      pattern->NewNode(preln_emb_elt_layernorm_op_repr())
          ->assert_is_op("fused_preln_embedding_eltwise_layernorm");
  auto* preln_emb_elt_layernorm_out_0 =
      pattern->NewNode(preln_emb_elt_layernorm_out_0_repr())
          ->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
                                "Out_0");
  auto* preln_emb_elt_layernorm_out_1 =
      pattern->NewNode(preln_emb_elt_layernorm_out_1_repr())
          ->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
                                "Out_1");

  // Add links for fused_preln_embedding_eltwise_layernorm op.
  preln_emb_elt_layernorm_op->LinksTo(
      {preln_emb_elt_layernorm_out_0, preln_emb_elt_layernorm_out_1});
}

void SkipLayernorm::operator()() {
  // Create nodes for skip_layernorm.
  auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())

@@ -51,6 +70,30 @@ void SkipLayernorm::operator()() {
      .LinksTo({skip_layernorm_out});
}

void PrelnSkipLayernorm::operator()() {
  // Create nodes for preln_skip_layernorm.
  auto* preln_skip_layernorm_x =
      pattern->NewNode(preln_skip_layernorm_x_repr())
          ->assert_is_op_input("preln_skip_layernorm", "X");
  auto* preln_skip_layernorm_y =
      pattern->NewNode(preln_skip_layernorm_y_repr())
          ->assert_is_op_input("preln_skip_layernorm", "Y");
  auto* preln_skip_layernorm_op =
      pattern->NewNode(preln_skip_layernorm_op_repr())
          ->assert_is_op("preln_skip_layernorm");
  auto* preln_skip_layernorm_out_0 =
      pattern->NewNode(preln_skip_layernorm_out_0_repr())
          ->assert_is_op_output("preln_skip_layernorm", "Out_0");
  auto* preln_skip_layernorm_out_1 =
      pattern->NewNode(preln_skip_layernorm_out_1_repr())
          ->assert_is_op_output("preln_skip_layernorm", "Out_1");

  // Add links for preln_skip_layernorm op.
  preln_skip_layernorm_op
      ->LinksFrom({preln_skip_layernorm_x, preln_skip_layernorm_y})
      .LinksTo({preln_skip_layernorm_out_0, preln_skip_layernorm_out_1});
}

void MultiheadMatmul::operator()() {
  // Create nodes for multihead_matmul.
  auto* multihead_matmul_input =

@@ -96,10 +139,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

  if (use_varseqlen && pos_id != "" && mask_id != "" &&
-      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+      (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
+       graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
      graph->Has(framework::ir::kMultiheadMatmulPass)) {
    VLOG(3) << "start varseqlen remove_padding_recover_padding_pass";
  } else {
    VLOG(3) << "remove_padding_recover_padding_pass check failed";
    return;
  }

@@ -131,9 +176,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
    remove_padding.SetOutput("Out", {remove_padding_out_name});

    // set out_threshold for int8
-    if (op_node->Op()->HasAttr("out_threshold")) {
+    if (op_node->Op()->HasAttr("Input_scale")) {
      remove_padding.SetAttr("out_threshold",
-                             op_node->Op()->GetAttr("out_threshold"));
+                             op_node->Op()->GetAttr("Input_scale"));
+    } else {
+      VLOG(3) << "remove_padding_op has not out_threshold, because next op has "
+                 "not Input_scale.";
    }

    auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);

@@ -194,6 +242,15 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
    if (op_node->Op()->HasAttr("out_threshold")) {
      recover_padding.SetAttr("out_threshold",
                              op_node->Op()->GetAttr("out_threshold"));
    } else if (op_node->Op()->HasAttr("out_0_threshold")) {
      recover_padding.SetAttr("out_threshold",
                              op_node->Op()->GetAttr("out_0_threshold"));
    } else if (op_node->Op()->HasAttr("out_1_threshold")) {
      recover_padding.SetAttr("out_threshold",
                              op_node->Op()->GetAttr("out_1_threshold"));
    } else {
      VLOG(3) << "recover_padding_op has not out_threshold, because previous "
                 "op has not out_*_threshold.";
    }

    auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);

@@ -241,9 +298,11 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
               "fused_embedding_eltwise_layernorm";

    GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op,
                              emb_elt_layernorm_op,
                              fused_embedding_eltwise_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out,
                              emb_elt_layernorm_out,
                              fused_embedding_eltwise_layernorm);

    insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out);

@@ -263,12 +322,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
               "multihead_matmul";

    GET_IR_NODE_FROM_SUBGRAPH(
        multihead_matmul_input, multihead_matmul_input, multihead_matmul);
    GET_IR_NODE_FROM_SUBGRAPH(
        multihead_matmul_op, multihead_matmul_op, multihead_matmul);
    GET_IR_NODE_FROM_SUBGRAPH(
        multihead_matmul_out, multihead_matmul_out, multihead_matmul);

    multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape();

@@ -289,14 +348,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
               "skip_layernorm";

    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x, skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y, skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op, skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out, skip_layernorm);

    std::vector<int64_t> skip_layernorm_x_shape =
        skip_layernorm_x->Var()->GetShape();

@@ -417,6 +476,86 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
  };
  gpd4(graph, handler4);

  GraphPatternDetector gpd5;
  patterns::PrelnEmbEltwiseLayernorm fused_preln_embedding_eltwise_layernorm(
      gpd5.mutable_pattern(), "remove_padding_recover_padding_pass");
  fused_preln_embedding_eltwise_layernorm();

  auto handler5 = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* graph) {
    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
               "fused_preln_embedding_eltwise_layernorm";

    GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_op,
                              preln_emb_elt_layernorm_op,
                              fused_preln_embedding_eltwise_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_0,
                              preln_emb_elt_layernorm_out_0,
                              fused_preln_embedding_eltwise_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_1,
                              preln_emb_elt_layernorm_out_1,
                              fused_preln_embedding_eltwise_layernorm);

    insert_recover_padding_op(preln_emb_elt_layernorm_op,
                              preln_emb_elt_layernorm_out_0);
    insert_recover_padding_op(preln_emb_elt_layernorm_op,
                              preln_emb_elt_layernorm_out_1);

    found_subgraph_count++;
  };
  gpd5(graph, handler5);

  GraphPatternDetector gpd6;
  patterns::PrelnSkipLayernorm preln_skip_layernorm(
      gpd6.mutable_pattern(), "remove_padding_recover_padding_pass");
  preln_skip_layernorm();

  auto handler6 = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* graph) {
    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
               "preln_skip_layernorm";

    GET_IR_NODE_FROM_SUBGRAPH(
        preln_skip_layernorm_x, preln_skip_layernorm_x, preln_skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(
        preln_skip_layernorm_y, preln_skip_layernorm_y, preln_skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(
        preln_skip_layernorm_op, preln_skip_layernorm_op, preln_skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_0,
                              preln_skip_layernorm_out_0,
                              preln_skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_1,
                              preln_skip_layernorm_out_1,
                              preln_skip_layernorm);

    std::vector<int64_t> skip_layernorm_x_shape =
        preln_skip_layernorm_x->Var()->GetShape();
    if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
      check_flag = false;
      VLOG(3) << "Transformer model remove_padding shape check failed, return "
                 "remove_padding pass.";
      return;
    }
    for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) {
      if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) {
        check_flag = false;
      }
    }
    if (!check_flag) {
      VLOG(3) << "Transformer model remove_padding shape check failed, return "
                 "remove_padding pass.";
      return;
    }

    insert_remove_padding_op(preln_skip_layernorm_x, preln_skip_layernorm_op);
    insert_remove_padding_op(preln_skip_layernorm_y, preln_skip_layernorm_op);
    insert_recover_padding_op(preln_skip_layernorm_op,
                              preln_skip_layernorm_out_0);
    insert_recover_padding_op(preln_skip_layernorm_op,
                              preln_skip_layernorm_out_1);

    found_subgraph_count++;
  };
  gpd6(graph, handler6);

  AddStatis(found_subgraph_count);
}
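Conceptually, remove_padding packs a padded [batch, max_seq_len] token layout into one contiguous variable-length buffer (with cumulative per-sequence offsets, the role the pos_id/cu_seqlens tensor plays in the varseqlen path), and recover_padding scatters it back; this is why each of the two outputs of the new pre-LN ops gets its own recover_padding op. A minimal packing/unpacking sketch under the assumption of a flat integer token buffer (helper names are hypothetical, not the Paddle ops themselves):

#include <cassert>
#include <vector>

// Pack padded [batch, max_seq_len] tokens into a flat varlen buffer and build
// cumulative sequence offsets (a toy analogue of cu_seqlens).
std::vector<int> RemovePadding(const std::vector<int>& padded, int batch,
                               int max_seq_len, const std::vector<int>& seq_lens,
                               std::vector<int>* cu_seqlens) {
  std::vector<int> packed;
  cu_seqlens->assign(1, 0);
  for (int b = 0; b < batch; ++b) {
    for (int t = 0; t < seq_lens[b]; ++t)
      packed.push_back(padded[b * max_seq_len + t]);
    cu_seqlens->push_back(cu_seqlens->back() + seq_lens[b]);
  }
  return packed;
}

// Scatter the packed tokens back into a zero-padded [batch, max_seq_len] layout.
std::vector<int> RecoverPadding(const std::vector<int>& packed, int batch,
                                int max_seq_len,
                                const std::vector<int>& cu_seqlens) {
  std::vector<int> padded(batch * max_seq_len, 0);
  for (int b = 0; b < batch; ++b)
    for (int t = cu_seqlens[b]; t < cu_seqlens[b + 1]; ++t)
      padded[b * max_seq_len + (t - cu_seqlens[b])] = packed[t];
  return padded;
}

int main() {
  std::vector<int> padded = {1, 2, 0, 0,   // sequence 0, length 2
                             3, 4, 5, 0};  // sequence 1, length 3
  std::vector<int> cu_seqlens;
  auto packed = RemovePadding(padded, 2, 4, {2, 3}, &cu_seqlens);
  assert(packed.size() == 5 && cu_seqlens.back() == 5);
  auto restored = RecoverPadding(packed, 2, 4, cu_seqlens);
  assert(restored == padded);  // round trip restores the padded layout
}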
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h
@@ -41,6 +41,16 @@ struct EmbEltwiseLayernorm : public PatternBase {
  PATTERN_DECL_NODE(emb_elt_layernorm_out);
};

struct PrelnEmbEltwiseLayernorm : public PatternBase {
  PrelnEmbEltwiseLayernorm(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "preln_emb_elt_layernorm") {}

  void operator()();

  PATTERN_DECL_NODE(preln_emb_elt_layernorm_op);
  PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_0);
  PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_1);
};

struct SkipLayernorm : public PatternBase {
  SkipLayernorm(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "skip_layernorm") {}

@@ -53,6 +63,19 @@ struct SkipLayernorm : public PatternBase {
  PATTERN_DECL_NODE(skip_layernorm_out);
};

struct PrelnSkipLayernorm : public PatternBase {
  PrelnSkipLayernorm(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "preln_skip_layernorm") {}

  void operator()();

  PATTERN_DECL_NODE(preln_skip_layernorm_x);
  PATTERN_DECL_NODE(preln_skip_layernorm_y);
  PATTERN_DECL_NODE(preln_skip_layernorm_op);
  PATTERN_DECL_NODE(preln_skip_layernorm_out_0);
  PATTERN_DECL_NODE(preln_skip_layernorm_out_1);
};

struct MultiheadMatmul : public PatternBase {
  MultiheadMatmul(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "multihead_matmul") {}
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
(This diff is collapsed and is not shown here.)
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
@@ -139,12 +139,12 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_scale, layer_norm_scale, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_variance, layer_norm_variance, fused_pattern);

    std::unordered_set<const Node *> del_node_set;

@@ -197,13 +197,15 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

  if (use_varseqlen && pos_id != "" && mask_id != "") {
-    if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+    if ((graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
+         graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
        graph->Has(framework::ir::kMultiheadMatmulPass)) {
      VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass";
    } else {
      PADDLE_THROW(platform::errors::Fatal(
          "Use transformer'varseqlen need "
-          "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
+          "trt_embedding_eltwise_layernorm_fuse_pass, "
+          "trt_multihead_matmul_fuse_pass. please use no_varseqlen"));
    }
  } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
    VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass";
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc
@@ -28,13 +28,21 @@ namespace tensorrt {
class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
#if IS_TRT_VERSION_GE(7000)
    VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
-    if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
+    auto pos_id_name = engine_->tensorrt_transformer_posid();
+    auto mask_id_name = engine_->tensorrt_transformer_maskid();
+    bool flag_prelayernorm = engine_->with_interleaved() &&
+                             engine_->use_varseqlen() &&
+                             pos_id_name != "" && mask_id_name != "";
+
+    if (!flag_prelayernorm) {
      PADDLE_THROW(platform::errors::Fatal(
-          "PrelnErnie: If you want to use oss, must be with interleaved"));
+          "PrelnErnie: If you want to use varseqlen, must be with interleaved, "
+          "set pos_id_name, set mask_id_name."));
    }
    framework::OpDesc op_desc(op, nullptr);
    bool enable_int8 = op_desc.HasAttr("enable_int8");

@@ -43,7 +51,6 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
          platform::errors::Fatal("use with_interleaved must be int8."));
    }
    auto word_id_name = op_desc.Input("WordId").front();
-    auto pos_id_name = op_desc.Input("PosId").front();
    engine_->Set("ernie_pos_name", new std::string(pos_id_name));

    auto sent_id_name = op_desc.Input("SentId").front();

@@ -51,6 +58,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
    auto pos_emb_name = op_desc.Input("PosEmbedding").front();
    auto sent_emb_name = op_desc.Input("SentEmbedding").front();

    engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
    engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
    engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));

    std::vector<std::string> emb_names;
    emb_names =
        std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};

@@ -81,7 +92,8 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
      input_embs.push_back(emb_data);
      emb_sizes.push_back(emb_size);
      PADDLE_ENFORCE_EQ(
          emb_dims.size(),
          2,
          platform::errors::InvalidArgument(
              "The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims."));
    }

@@ -97,23 +109,31 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
    int output_int8 = 1;
    PADDLE_ENFORCE_EQ(
        input_num,
        3,
        platform::errors::InvalidArgument(
            "When using oss and var-len, embedding_eltwise_layernorm op"
            "should have 3 inputs only, but got %d.",
            input_num));
    const std::vector<nvinfer1::PluginField> fields{
        {"bert_embeddings_layernorm_beta",
         bias,
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(bias_size)},
        {"bert_embeddings_layernorm_gamma",
         scale,
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(scale_size)},
        {"bert_embeddings_word_embeddings",
         input_embs[0],
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(emb_sizes[0])},
        {"bert_embeddings_token_type_embeddings",
         input_embs[2],
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(emb_sizes[2])},
        {"bert_embeddings_position_embeddings",
         input_embs[1],
         nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(emb_sizes[1])},
        {"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1},

@@ -136,8 +156,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
    plugin_inputs.emplace_back(engine_->GetITensor(
        pos_id_name));  // cu_seqlens,
                        // eval_placeholder_2
-    auto max_seqlen_tensor =
-        engine_->GetITensor(engine_->network()->getInput(3)->getName());
+    auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
    auto* shuffle_layer =
        TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
    nvinfer1::Dims shape_dim;
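The plugin fields passed above mirror the standard BERT embedding computation: look up the word, position, and token-type (sentence) embeddings, sum them element-wise, then apply layer normalization with the gamma/beta weights. A scalar sketch of that fused computation for one token, assuming row-major [rows, hidden] embedding tables (this is only an illustration, not the TensorRT plugin itself):

#include <cmath>
#include <vector>

// One token of fused embedding_eltwise_layernorm:
//   emb = word_emb[word_id] + pos_emb[pos_id] + sent_emb[sent_id]
//   out = LayerNorm(emb) * gamma + beta
// The pre-LN variant would additionally return emb itself as Out_0.
std::vector<float> EmbEltwiseLayerNorm(
    int word_id, int pos_id, int sent_id, int hidden,
    const std::vector<float>& word_emb, const std::vector<float>& pos_emb,
    const std::vector<float>& sent_emb, const std::vector<float>& gamma,
    const std::vector<float>& beta, float eps = 1e-5f) {
  std::vector<float> emb(hidden);
  for (int i = 0; i < hidden; ++i)
    emb[i] = word_emb[word_id * hidden + i] + pos_emb[pos_id * hidden + i] +
             sent_emb[sent_id * hidden + i];

  float mean = 0.f, var = 0.f;
  for (float v : emb) mean += v;
  mean /= hidden;
  for (float v : emb) var += (v - mean) * (v - mean);
  var /= hidden;

  std::vector<float> out(hidden);
  for (int i = 0; i < hidden; ++i)
    out[i] = (emb[i] - mean) / std::sqrt(var + eps) * gamma[i] + beta[i];
  return out;
}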