机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 2810dfea (unverified)
Authored by Wangzheee on Jun 02, 2022; committed via GitHub on Jun 02, 2022
[Paddle-Inference] new general transformer inference support (#43077)
* new general transformer inference support
Parent: 0cb9dae5
Showing 55 changed files with 4124 additions and 507 deletions (+4124 -507)
paddle/fluid/framework/ir/CMakeLists.txt  +3 -0
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc  +5 -3
paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc  +4 -3
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc  +161 -43
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h  +8 -0
paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc  +92 -93
paddle/fluid/framework/ir/set_transformer_input_convert_pass.h  +19 -24
paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc  +477 -0
paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.h  +167 -0
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc  +1546 -0
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.h  +179 -0
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc  +232 -0
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h  +87 -0
paddle/fluid/inference/analysis/argument.h  +5 -1
paddle/fluid/inference/analysis/ir_pass_manager.cc  +5 -1
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc  +8 -2
paddle/fluid/inference/api/analysis_config.cc  +9 -3
paddle/fluid/inference/api/analysis_predictor.cc  +20 -1
paddle/fluid/inference/api/paddle_analysis_config.h  +5 -3
paddle/fluid/inference/api/paddle_api.h  +6 -0
paddle/fluid/inference/api/paddle_pass_builder.cc  +16 -16
paddle/fluid/inference/capi_exp/pd_config.cc  +3 -3
paddle/fluid/inference/capi_exp/pd_config.h  +1 -1
paddle/fluid/inference/goapi/config.go  +2 -2
paddle/fluid/inference/goapi/config_test.go  +2 -2
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt  +4 -0
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc  +13 -15
paddle/fluid/inference/tensorrt/convert/fc_op.cc  +1 -2
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc  +11 -19
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc  +1 -1
paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc  +3 -2
paddle/fluid/inference/tensorrt/convert/recover_padding_op.cc  +76 -0
paddle/fluid/inference/tensorrt/convert/remove_padding_op.cc  +69 -0
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc  +6 -3
paddle/fluid/inference/tensorrt/convert/slice_op.cc  +6 -42
paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc  +72 -0
paddle/fluid/inference/tensorrt/engine.h  +17 -4
paddle/fluid/inference/tensorrt/op_teller.cc  +8 -2
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt  +4 -1
paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu  +120 -0
paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h  +133 -0
paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu  +118 -0
paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h  +133 -0
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu  +0 -197
paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu  +110 -0
paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h  +134 -0
paddle/fluid/inference/tests/api/analyzer_capi_exp_gpu_tester.cc  +1 -1
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc  +5 -1
paddle/fluid/inference/tests/infer_ut/test_ernie_xnli_int8.cc  +1 -1
paddle/fluid/inference/utils/table_printer_tester.cc  +1 -1
paddle/fluid/pybind/inference_api.cc  +3 -2
python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py  +3 -3
python/paddle/fluid/tests/unittests/ir/inference/quant_dequant_test.py  +3 -3
python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms3_op.py  +3 -3
python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py  +3 -3
paddle/fluid/framework/ir/CMakeLists.txt

@@ -107,6 +107,9 @@ target_link_libraries(generate_pass pass_desc_proto)
 if(WITH_TENSORRT)
   pass_library(trt_map_matmul_to_mul_pass inference)
+  pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
+  pass_library(trt_multihead_matmul_fuse_pass inference)
+  pass_library(trt_skip_layernorm_fuse_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_skip_layernorm_fuse_pass inference)
   pass_library(set_transformer_input_convert_pass inference)
...
paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc

@@ -430,13 +430,15 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
   FusePassBase::Init(name_scope_, graph);
   bool enable_int8 = Get<bool>("enable_int8");
-  bool use_oss = Get<bool>("use_oss");
+  bool use_varseqlen = Get<bool>("use_varseqlen");
   bool with_interleaved = Get<bool>("with_interleaved");
   bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
-  if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) {
+  if (!(enable_int8 && use_varseqlen && with_interleaved &&
+        with_dynamic_shape)) {
     VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
                "enable_int8, "
-               "use_oss, with_interleaved, with_dynamic_shape. Stop this pass, "
+               "use_varseqlen, with_interleaved, with_dynamic_shape. Stop this "
+               "pass, "
                "please reconfig.";
     return;
   }
...
paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc

@@ -109,12 +109,13 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   FusePassBase::Init("preln_skip_layernorm_fuse", graph);
   bool enable_int8 = Get<bool>("enable_int8");
-  bool use_oss = Get<bool>("use_oss");
+  bool use_varseqlen = Get<bool>("use_varseqlen");
   bool with_interleaved = Get<bool>("with_interleaved");
   bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
-  if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) {
+  if (!(enable_int8 && use_varseqlen && with_interleaved &&
+        with_dynamic_shape)) {
     VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
-               "use_oss, "
+               "use_varseqlen, "
                "with_interleaved, with_dynamic_shape. Stop this pass, please "
                "reconfig. ";
     return;
...
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc

@@ -22,6 +22,19 @@ namespace paddle {
 namespace framework {
 namespace ir {
 namespace patterns {
+
+void EmbEltwiseLayernorm::operator()() {
+  // Create nodes for fused_embedding_eltwise_layernorm.
+  auto* emb_elt_layernorm_op =
+      pattern->NewNode(emb_elt_layernorm_op_repr())
+          ->assert_is_op("fused_embedding_eltwise_layernorm");
+  auto* emb_elt_layernorm_out =
+      pattern->NewNode(emb_elt_layernorm_out_repr())
+          ->assert_is_op_output("fused_embedding_eltwise_layernorm", "Out");
+
+  // Add links for fused_embedding_eltwise_layernorm op.
+  emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out});
+}
+
 void SkipLayernorm::operator()() {
   // Create nodes for skip_layernorm.
   auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())
...
@@ -59,16 +72,12 @@ void Fc::operator()() {
   auto* fc_input =
       pattern->NewNode(fc_input_repr())->assert_is_op_input("fc", "Input");
   auto* fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc");
-  auto* fc_out =
-      pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out");
-
-  // Add links for fc op.
-  fc_op->LinksFrom({fc_input}).LinksTo({fc_out});
+  fc_op->LinksFrom({fc_input});
 }

 void Activation::operator()() {
   // Create nodes for activation.
-  std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "tanh"};
+  std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "gelu"};
   auto* activation_input = pattern->NewNode(activation_input_repr())
                                ->assert_is_ops_input(activation_ops);
   auto* activation_op =
...
@@ -82,6 +91,18 @@ void Activation::operator()() {
 }  // namespace patterns

 void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
+  bool use_varseqlen = Get<bool>("use_varseqlen");
+  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
+  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
+
+  if (use_varseqlen && pos_id != "" && mask_id != "" &&
+      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+      graph->Has(framework::ir::kMultiheadMatmulPass)) {
+    VLOG(3) << "start varseqlen remove_padding_recover_padding_pass";
+  } else {
+    return;
+  }
+
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   FusePassBase::Init(name_scope_, graph);
...
@@ -91,14 +112,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
   // Create an remove_padding op node
   auto insert_remove_padding_op = [&](Node* input_node, Node* op_node) {
     // create op, var in graph
-    OpDesc remove_padding;
+    OpDesc remove_padding(op_node->Op()->Block());
     std::string remove_padding_out_name =
         input_node->Name() + ".remove_padding";
-    VarDesc remove_padding_out(remove_padding_out_name);
-    remove_padding_out.SetDataType(input_node->Var()->GetDataType());
-    remove_padding_out.SetShape(input_node->Var()->GetShape());
-    remove_padding_out.SetPersistable(false);
+    auto* remove_padding_out =
+        op_node->Op()->Block()->Var(remove_padding_out_name);
+    remove_padding_out->SetDataType(input_node->Var()->GetDataType());
+    remove_padding_out->SetShape(input_node->Var()->GetShape());
+    remove_padding_out->SetPersistable(false);

     // remove_padding_op
     remove_padding.SetType("remove_padding");
...
@@ -110,7 +131,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     remove_padding.SetOutput("Out", {remove_padding_out_name});

     auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);
-    auto remove_padding_out_node = graph->CreateVarNode(&remove_padding_out);
+    auto remove_padding_out_node = graph->CreateVarNode(remove_padding_out);

     // replace link
     for (size_t i = 0; i < input_node->outputs.size(); ++i) {
...
@@ -145,13 +166,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
   // create an remove_padding op node
   auto insert_recover_padding_op = [&](Node* op_node, Node* out_node) {
     // create op, var in graph
-    OpDesc recover_padding;
+    OpDesc recover_padding(op_node->Op()->Block());
     std::string recover_padding_input_name =
         out_node->Name() + ".recover_padding";
-    VarDesc recover_padding_input(recover_padding_input_name);
-    recover_padding_input.SetDataType(out_node->Var()->GetDataType());
-    recover_padding_input.SetShape(out_node->Var()->GetShape());
-    recover_padding_input.SetPersistable(false);
+    auto* recover_padding_input =
+        op_node->Op()->Block()->Var(recover_padding_input_name);
+    recover_padding_input->SetDataType(out_node->Var()->GetDataType());
+    recover_padding_input->SetShape(out_node->Var()->GetShape());
+    recover_padding_input->SetPersistable(false);

     // recover_padding_op
     recover_padding.SetType("recover_padding");
...
@@ -164,7 +186,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);
     auto recover_padding_input_node =
-        graph->CreateVarNode(&recover_padding_input);
+        graph->CreateVarNode(recover_padding_input);

     // replace link
     for (size_t i = 0; i < op_node->outputs.size(); ++i) {
...
@@ -195,39 +217,36 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     op_node->Op()->RenameOutput(out_node->Name(), recover_padding_input_name);
   };

-  GraphPatternDetector gpd1;
-  bool check_flag = true;
-  patterns::SkipLayernorm skip_layernorm(
-      gpd1.mutable_pattern(), "remove_padding_recover_padding_pass");
-  skip_layernorm();
-  auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                      Graph* graph) {
-    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
-               "skip_layernorm";
-    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x,
-                              skip_layernorm);
-    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y,
-                              skip_layernorm);
-    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op,
-                              skip_layernorm);
-    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out,
-                              skip_layernorm);
-    insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op);
-    insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op);
-    insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out);
-    found_subgraph_count++;
-  };
-  gpd1(graph, handler1);
+  GraphPatternDetector gpd0;
+  patterns::EmbEltwiseLayernorm fused_embedding_eltwise_layernorm(
+      gpd0.mutable_pattern(), "remove_padding_recover_padding_pass");
+  fused_embedding_eltwise_layernorm();
+  auto handler0 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
+               "fused_embedding_eltwise_layernorm";
+    GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op, emb_elt_layernorm_op,
+                              fused_embedding_eltwise_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out, emb_elt_layernorm_out,
+                              fused_embedding_eltwise_layernorm);
+    insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out);
+    found_subgraph_count++;
+  };
+  gpd0(graph, handler0);

-  GraphPatternDetector gpd2;
+  GraphPatternDetector gpd1;
   patterns::MultiheadMatmul multihead_matmul(
-      gpd2.mutable_pattern(), "remove_padding_recover_padding_pass");
+      gpd1.mutable_pattern(), "remove_padding_recover_padding_pass");
   multihead_matmul();
-  auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+  std::vector<int64_t> multihead_matmul_input_shape;
+  auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* graph) {
     VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
                "multihead_matmul";
...
@@ -239,11 +258,57 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out,
                               multihead_matmul);

+    multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape();
+
     insert_remove_padding_op(multihead_matmul_input, multihead_matmul_op);
     insert_recover_padding_op(multihead_matmul_op, multihead_matmul_out);

     found_subgraph_count++;
   };
+  gpd1(graph, handler1);
+
+  GraphPatternDetector gpd2;
+  patterns::SkipLayernorm skip_layernorm(
+      gpd2.mutable_pattern(), "remove_padding_recover_padding_pass");
+  skip_layernorm();
+  auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
+               "skip_layernorm";
+    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x,
+                              skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y,
+                              skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op,
+                              skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out,
+                              skip_layernorm);
+
+    bool check_flag = true;
+    std::vector<int64_t> skip_layernorm_x_shape =
+        skip_layernorm_x->Var()->GetShape();
+    if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
+      check_flag = false;
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+    for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) {
+      if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) {
+        check_flag = false;
+      }
+    }
+    if (!check_flag) {
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+    insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op);
+    insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op);
+    insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out);
+    found_subgraph_count++;
+  };
   gpd2(graph, handler2);

   GraphPatternDetector gpd3;
...
@@ -257,11 +322,39 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, fc);
     GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc);
-    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc);

-    insert_remove_padding_op(fc_input, fc_op);
-    insert_recover_padding_op(fc_op, fc_out);
+    bool check_flag = true;
+    std::vector<int64_t> fc_input_shape = fc_input->Var()->GetShape();
+    if ((fc_input_shape.size() != multihead_matmul_input_shape.size()) ||
+        (fc_input_shape.size() != 3)) {
+      check_flag = false;
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+    if (fc_input_shape[0] != multihead_matmul_input_shape[0]) {
+      check_flag = false;
+    }
+    if (fc_input_shape[1] != multihead_matmul_input_shape[1]) {
+      check_flag = false;
+    }
+    if ((fc_input_shape[2] != multihead_matmul_input_shape[2]) &&
+        (fc_input_shape[2] != 4 * multihead_matmul_input_shape[2])) {
+      check_flag = false;
+    }
+    if (BOOST_GET_CONST(int, fc_op->Op()->GetAttr("in_num_col_dims")) != 2) {
+      check_flag = false;
+    }
+    if (!check_flag) {
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+    fc_op->Op()->RemoveAttr("in_num_col_dims");
+    fc_op->Op()->SetAttr("in_num_col_dims", 1);
+
+    insert_remove_padding_op(fc_input, fc_op);
+    insert_recover_padding_op(fc_op, fc_op->outputs[0]);
     found_subgraph_count++;
   };
   gpd3(graph, handler3);
...
@@ -280,6 +373,31 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(activation_op, activation_op, activation);
     GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, activation);

+    bool check_flag = true;
+    std::vector<int64_t> activation_input_shape =
+        activation_input->Var()->GetShape();
+    if ((activation_input_shape.size() !=
+         multihead_matmul_input_shape.size()) ||
+        (activation_input_shape.size() != 3)) {
+      check_flag = false;
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+    if (activation_input_shape[0] != multihead_matmul_input_shape[0]) {
+      check_flag = false;
+    }
+    if (activation_input_shape[1] != multihead_matmul_input_shape[1]) {
+      check_flag = false;
+    }
+    if ((activation_input_shape[2] != multihead_matmul_input_shape[2]) &&
+        (activation_input_shape[2] != 4 * multihead_matmul_input_shape[2])) {
+      check_flag = false;
+    }
+    if (!check_flag) {
+      VLOG(3) << "Transformer model remove_padding shape check failed, return "
+                 "remove_padding pass.";
+      return;
+    }
+
     insert_remove_padding_op(activation_input, activation_op);
     insert_recover_padding_op(activation_op, activation_out);
...
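The fc and activation handlers above only rewrite subgraphs whose input shape is consistent with the multihead_matmul input: rank 3, matching batch and sequence dims, and a hidden dim equal to either the model width or four times it (the FFN expansion). That guard can be restated as a standalone predicate (a hypothetical helper mirroring the checks in the pass, not code from the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns true when `in` (the candidate op's input shape) is compatible with
// `mha_in` (the multihead_matmul input shape): same rank 3, same batch and
// sequence dims, and hidden dim equal to the width or 4x the width.
bool ShapeCompatible(const std::vector<int64_t>& in,
                     const std::vector<int64_t>& mha_in) {
  if (in.size() != mha_in.size() || in.size() != 3) return false;
  if (in[0] != mha_in[0] || in[1] != mha_in[1]) return false;
  return in[2] == mha_in[2] || in[2] == 4 * mha_in[2];
}
```

If any check fails, the pass leaves that subgraph alone rather than inserting a remove_padding/recover_padding pair around an op whose layout it cannot prove matches the packed stream.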
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h

@@ -32,6 +32,14 @@ namespace paddle {
 namespace framework {
 namespace ir {
 namespace patterns {
+struct EmbEltwiseLayernorm : public PatternBase {
+  EmbEltwiseLayernorm(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "emb_elt_layernorm") {}
+
+  void operator()();
+  PATTERN_DECL_NODE(emb_elt_layernorm_op);
+  PATTERN_DECL_NODE(emb_elt_layernorm_out);
+};
 struct SkipLayernorm : public PatternBase {
   SkipLayernorm(PDPattern *pattern, const std::string &name_scope)
...
paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc
浏览文件 @
2810dfea
...
@@ -21,129 +21,134 @@
...
@@ -21,129 +21,134 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
SetTransformerInputConvertPass
::
SetTransformerInputConvertPass
()
{
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
End
();
}
namespace
patterns
{
namespace
patterns
{
void
SetTransformerInputConvert
::
operator
()()
{
void
SetTransformerInputConvert
::
operator
()(
const
std
::
string
&
pos_id
)
{
std
::
unordered_set
<
std
::
string
>
lookup_table_ops
{
"lookup_table"
,
std
::
unordered_set
<
std
::
string
>
lookup_table_ops
{
"lookup_table"
,
"lookup_table_v2"
};
"lookup_table_v2"
};
// Create nodes for lookup_table1 op.
// Create nodes for lookup_table.
auto
*
lookup_table1_x
=
pattern
->
NewNode
(
lookup_table1_x_repr
())
auto
*
lookup_table_id
=
->
assert_is_ops_input
(
lookup_table_ops
,
"Ids"
);
pattern
->
NewNode
(
lookup_table_id_repr
())
auto
*
lookup_table1_w
=
pattern
->
NewNode
(
lookup_table1_w_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"Ids"
)
->
assert_is_ops_input
(
lookup_table_ops
,
"W"
);
->
assert_more
([
&
](
Node
*
node
)
{
return
node
->
Name
()
==
pos_id
;
});
auto
*
lookup_table1_op
=
auto
*
lookup_table_op
=
pattern
->
NewNode
(
lookup_table1_repr
())
->
assert_is_ops
(
lookup_table_ops
);
pattern
->
NewNode
(
lookup_table_repr
())
->
assert_is_ops
(
lookup_table_ops
);
auto
*
lookup_table1_out
=
pattern
->
NewNode
(
lookup_table1_out_repr
())
->
assert_is_ops_output
(
lookup_table_ops
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
// Create nodes for lookup_table2 op.
auto
*
lookup_table2_x
=
pattern
->
NewNode
(
lookup_table2_x_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"Ids"
);
auto
*
lookup_table2_w
=
pattern
->
NewNode
(
lookup_table2_w_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"W"
);
auto
*
lookup_table2_op
=
pattern
->
NewNode
(
lookup_table2_repr
())
->
assert_is_ops
(
lookup_table_ops
);
auto
*
lookup_table2_out
=
pattern
->
NewNode
(
lookup_table2_out_repr
())
->
assert_is_ops_output
(
lookup_table_ops
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
// Create nodes for elementwise_add op.
auto
*
elementwise_op
=
pattern
->
NewNode
(
elementwise_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
elementwise_out
=
pattern
->
NewNode
(
elementwise_out_repr
())
->
AsOutput
()
->
assert_is_only_output_of_op
(
"elementwise_add"
);
// links nodes.
// links nodes.
lookup_table1_op
->
LinksFrom
({
lookup_table1_x
,
lookup_table1_w
})
lookup_table_op
->
LinksFrom
({
lookup_table_id
});
.
LinksTo
({
lookup_table1_out
});
lookup_table2_op
->
LinksFrom
({
lookup_table2_x
,
lookup_table2_w
})
.
LinksTo
({
lookup_table2_out
});
elementwise_op
->
LinksFrom
({
lookup_table1_out
,
lookup_table2_out
})
.
LinksTo
({
elementwise_out
});
}
}
void MultiheadMatmulOP::operator()() {
  // Create nodes for multihead_matmul op.
  auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr())
                               ->assert_is_op("multihead_matmul");
  auto *multihead_matmul_out =
      pattern->NewNode(multihead_matmul_out_repr())
          ->assert_is_op_output("multihead_matmul", "Out");

  // links nodes.
  multihead_matmul_out->LinksFrom({multihead_matmul});
}
}  // namespace patterns
void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const {
  bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");

  if (!(graph->Has(framework::ir::kMultiheadMatmulPass) &&
        with_dynamic_shape && (pos_id != ""))) {
    VLOG(3) << "Transformer model need MultiheadMatmul, and "
               "with_dynamic_shape. Stop this pass, "
               "please reconfig.";
    return;
  }
  PADDLE_ENFORCE_NOT_NULL(
      graph,
      platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init(name_scope_, graph);
  int found_subgraph_count = 0;
  Node *transformer_input_convert_out0_node;
  Node *transformer_input_convert_out1_node;
  GraphPatternDetector gpd0;
  patterns::SetTransformerInputConvert fused_pattern(
      gpd0.mutable_pattern(), "transformer_input_convert_pass");
  fused_pattern(pos_id);

  auto handler0 = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *graph) {
    VLOG(3)
        << "transformer_input_convert_pass for pos_id, max_seqlen, mask_tensor";
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table_id, lookup_table_id, fused_pattern);

    // create op, var in graph
    OpDesc new_desc(lookup_table->Op()->Block());
    new_desc.SetType("transformer_input_convert");

    // inputs
    new_desc.SetInput("Input", {lookup_table_id->Name()});

    // outputs
    std::string transformer_input_convert_out0_name = "pos_id_tensor";
    std::string transformer_input_convert_out1_name = "max_seqlen_tensor";
    std::string transformer_input_convert_out2_name = "mask_tensor";
    std::vector<std::string> output_0 = {transformer_input_convert_out0_name};
    std::vector<std::string> output_1 = {transformer_input_convert_out1_name};
    std::vector<std::string> output_2 = {transformer_input_convert_out2_name};
    new_desc.SetOutput("PosId", output_0);
    new_desc.SetOutput("MaxSeqlen", output_1);
    new_desc.SetOutput("MaskTensor", output_2);

    auto *transformer_input_convert_out0 =
        lookup_table->Op()->Block()->Var(transformer_input_convert_out0_name);
    auto *transformer_input_convert_out1 =
        lookup_table->Op()->Block()->Var(transformer_input_convert_out1_name);
    auto *transformer_input_convert_out2 =
        lookup_table->Op()->Block()->Var(transformer_input_convert_out2_name);
    transformer_input_convert_out0->SetDataType(proto::VarType::INT32);
    transformer_input_convert_out1->SetDataType(proto::VarType::INT32);
    transformer_input_convert_out2->SetDataType(proto::VarType::INT32);
    transformer_input_convert_out0->SetShape({-1});
    transformer_input_convert_out1->SetShape({-1});
    transformer_input_convert_out2->SetShape({-1});
    transformer_input_convert_out0->SetPersistable(false);
    transformer_input_convert_out1->SetPersistable(false);
    transformer_input_convert_out2->SetPersistable(false);

    auto new_op_node = graph->CreateOpNode(&new_desc);
    transformer_input_convert_out0_node =
        graph->CreateVarNode(transformer_input_convert_out0);
    transformer_input_convert_out1_node =
        graph->CreateVarNode(transformer_input_convert_out1);
    auto transformer_input_convert_out2_node =
        graph->CreateVarNode(transformer_input_convert_out2);

    // needn't create variable in scope
    IR_NODE_LINK_TO(lookup_table_id, new_op_node);
    IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out0_node);
    IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out1_node);
    IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out2_node);
    found_subgraph_count++;
  };
  gpd0(graph, handler0);

  GraphPatternDetector gpd1;
  patterns::MultiheadMatmulOP multihead_matmul_pattern(
      gpd1.mutable_pattern(), "transformer_input_convert_pass");
  multihead_matmul_pattern();

  auto handler1 = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *graph) {
    VLOG(3) << "link pos_id, max_seqlen to multihead_matmul.";
    GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul,
                              multihead_matmul_pattern);

    IR_NODE_LINK_TO(transformer_input_convert_out0_node, multihead_matmul);
    IR_NODE_LINK_TO(transformer_input_convert_out1_node, multihead_matmul);
  };
  gpd1(graph, handler1);

  AddStatis(found_subgraph_count);
}
...
@@ -153,9 +158,3 @@ void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const {
REGISTER_PASS(set_transformer_input_convert_pass,
              paddle::framework::ir::SetTransformerInputConvertPass);
REGISTER_PASS_CAPABILITY(set_transformer_input_convert_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("lookup_table", 1)
            .LE("lookup_table_v2", 1)
            .LE("elementweise_add", 1));
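The pass only wires `pos_id_tensor` and `max_seqlen_tensor` variables into the graph; their contents are produced at runtime by the `transformer_input_convert` op. As a rough sketch of the varseqlen convention such tensors typically follow in TensorRT-style transformer kernels (the helper `BuildVarSeqlenInfo` below is hypothetical, not code from this PR): `pos_id` holds cumulative sequence offsets and `max_seqlen` the longest sequence length.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical illustration: from per-sequence token counts, build the
// cumulative-offset tensor that locates each sequence in a packed batch,
// plus the maximum sequence length.
struct VarSeqlenInfo {
  std::vector<int> pos_id;  // pos_id[i] = start offset of sequence i
  int max_seqlen;
};

VarSeqlenInfo BuildVarSeqlenInfo(const std::vector<int> &seq_lens) {
  VarSeqlenInfo info{{0}, 0};
  for (int len : seq_lens) {
    info.pos_id.push_back(info.pos_id.back() + len);  // prefix sum of lengths
    info.max_seqlen = std::max(info.max_seqlen, len);
  }
  return info;
}
```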
paddle/fluid/framework/ir/set_transformer_input_convert_pass.h
...
@@ -33,41 +33,36 @@ namespace framework {
namespace ir {
namespace patterns {

//  in_var   emb
//    |       |
//   lookup_table
//        |
//     lkt_var
//
struct SetTransformerInputConvert : public PatternBase {
  SetTransformerInputConvert(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "transformer_input_convert_pass") {}

  void operator()(const std::string &pos_id);

  // declare operator node's name
  PATTERN_DECL_NODE(lookup_table);
  // declare variable node's name
  PATTERN_DECL_NODE(lookup_table_id);
};

struct MultiheadMatmulOP : public PatternBase {
  MultiheadMatmulOP(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "transformer_input_convert_pass") {}

  void operator()();

  // declare operator node's name
  PATTERN_DECL_NODE(multihead_matmul);
  PATTERN_DECL_NODE(multihead_matmul_out);
};
}  // namespace patterns

class SetTransformerInputConvertPass : public FusePassBase {
 public:
  SetTransformerInputConvertPass() {}
  virtual ~SetTransformerInputConvertPass() {}

 protected:
...
paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc
0 → 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {

class Node;

}  // namespace ir
}  // namespace framework
}  // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

static PDNode *create_emb_vars(PDPattern *pattern, const std::string &name,
                               const std::string &arg,
                               bool is_persist = false) {
  std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                "lookup_table_v2"};
  PDNode *node =
      pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg);
  if (is_persist) return node->assert_is_persistable_var();
  return node;
}

static PDNode *create_emb_out_vars(PDPattern *pattern, const std::string &name,
                                   const std::string &arg) {
  std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                "lookup_table_v2"};
  PDNode *node = pattern->NewNode(name)
                     ->assert_is_only_output_of_ops(embedding_ops)
                     ->assert_is_op_input("elementwise_add", arg)
                     ->AsIntermediate();
  return node;
}
void TrtEmbedding2Eltwise1Pattern::operator()() {
  auto *lookup_table1_x =
      create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
  auto *lookup_table2_x =
      create_emb_vars(pattern, lookup_table2_x_repr(), "Ids");
  auto *lookup_table1_w =
      create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
  auto *lookup_table2_w =
      create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
  std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                "lookup_table_v2"};
  auto *feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
  auto *feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed");
  auto *lookup_table1 =
      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
  auto *lookup_table2 =
      pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops);
  auto *lookup_table1_out =
      create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
  auto *lookup_table2_out =
      create_emb_out_vars(pattern, lookup_table2_out_repr(), "Y");
  auto *eltwise_add =
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto *eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
                              ->assert_is_op_output("elementwise_add");
  feed1->LinksTo({lookup_table1_x});
  lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
      .LinksTo({lookup_table1_out});
  feed2->LinksTo({lookup_table2_x});
  lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
      .LinksTo({lookup_table2_out});
  eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})
      .LinksTo({eltwise_add_out});
}
void TrtEmbedding1Eltwise1Pattern::operator()() {
  auto *lookup_table1_x =
      create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
  auto *lookup_table1_w =
      create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
  std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                "lookup_table_v2"};
  auto *feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
  auto *lookup_table1 =
      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
  auto *lookup_table1_out =
      create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
  auto *eltwise_add =
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto *eltwise_add_in = pattern->NewNode(eltwise_add_in_repr())
                             ->assert_is_op_input("elementwise_add", "X")
                             ->assert_is_op_output("elementwise_add");
  auto *eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
                              ->assert_is_op_output("elementwise_add");
  lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
      .LinksTo({lookup_table1_out});
  feed1->LinksTo({lookup_table1_x});
  eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
      .LinksTo({eltwise_add_out});
}
void TrtSkipLayerNorm::operator()() {
  auto *eltwise_add =
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto *eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
                              ->assert_is_op_output("elementwise_add")
                              ->assert_is_op_input("layer_norm", "X")
                              ->AsIntermediate();
  auto *layer_norm =
      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
  auto *layer_norm_out = pattern->NewNode(layer_norm_out_repr())
                             ->assert_is_op_output("layer_norm", "Y")
                             ->AsOutput();
  auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
                                  ->AsInput()
                                  ->assert_is_persistable_var()
                                  ->assert_is_op_input("layer_norm", "Bias");
  auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
                                   ->AsInput()
                                   ->assert_is_persistable_var()
                                   ->assert_is_op_input("layer_norm", "Scale");
  auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
                                  ->AsOutput()
                                  ->assert_is_op_output("layer_norm", "Mean");
  auto *layer_norm_variance_var =
      pattern->NewNode(layer_norm_variance_repr())
          ->AsOutput()
          ->assert_is_op_output("layer_norm", "Variance");

  eltwise_add->LinksTo({eltwise_add_out});
  layer_norm
      ->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var})
      .LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var});
}
}  // namespace patterns
int TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(
    Graph *graph, const std::string &name_scope
    /*const Scope* scope*/) const {
  GraphPatternDetector gpd;
  auto *pattern = gpd.mutable_pattern();
  bool use_varseqlen = Get<bool>("use_varseqlen");
  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

  std::vector<std::vector<std::pair<Node *, Node *>>> start_pattern_in_nodes;
  std::vector<Node *> start_pattern_out_node;
  std::vector<std::unordered_set<Node *>> start_pattern_remove_nodes;

  // Create pattern.
  patterns::TrtEmbedding2Eltwise1Pattern start_pattern(pattern,
                                                       name_scope + "/start");
  start_pattern();
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
                              start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out,
                              start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "Pass(TrtEmbedding2Eltwise1Pattern) in op compat failed.";
      return;
    }
    std::vector<std::pair<Node *, Node *>> ins;
    ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w));
    ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w));
    start_pattern_in_nodes.push_back(ins);
    start_pattern_out_node.push_back(eltwise_add_out);

    std::unordered_set<Node *> rm_nodes;
    rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out,
                     lookup_table2_out, eltwise_add, eltwise_add_out});
    start_pattern_remove_nodes.push_back(rm_nodes);
  };
  gpd(graph, handler);
  std::vector<std::pair<Node *, Node *>> inner_pattern_ins;
  std::vector<Node *> inner_pattern_tmp_in;
  std::vector<Node *> inner_pattern_out;
  std::vector<std::unordered_set<Node *>> inner_pattern_remove_nodes;

  GraphPatternDetector gpd2;
  auto *pattern2 = gpd2.mutable_pattern();
  patterns::TrtEmbedding1Eltwise1Pattern second_pattern(pattern2,
                                                        name_scope + "/second");
  second_pattern();
  auto handler2 = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
                              second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "Pass(TrtEmbedding1Eltwise1Pattern) in op compat failed.";
      return;
    }
    auto in = std::make_pair(lookup_table1_x, lookup_table1_w);
    inner_pattern_ins.push_back(in);
    inner_pattern_tmp_in.push_back(eltwise_add_in);
    inner_pattern_out.push_back(eltwise_add_out);

    std::unordered_set<Node *> rm_nodes;
    rm_nodes.insert(
        {lookup_table1, lookup_table1_out, eltwise_add, eltwise_add_out});
    inner_pattern_remove_nodes.push_back(rm_nodes);
  };
  gpd2(graph, handler2);
  std::vector<Node *> end_pattern_elt_out;
  std::vector<Node *> end_pattern_scales;
  std::vector<Node *> end_pattern_biases;
  std::vector<Node *> end_pattern_out;
  std::vector<Node *> end_patter_layernorms;
  std::vector<std::unordered_set<Node *>> end_pattern_remove_nodes;
  GraphPatternDetector gpd3;
  auto *pattern3 = gpd3.mutable_pattern();
  patterns::TrtSkipLayerNorm skip_layernorm_pattern(pattern3,
                                                    name_scope + "/third");
  skip_layernorm_pattern();
  auto handler3 = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out,
                              skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
                              skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias,
                              skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
                              skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean,
                              skip_layernorm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
                              skip_layernorm_pattern);
    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "Pass(TrtSkipLayerNorm) in op compat failed.";
      return;
    }
    end_pattern_elt_out.push_back(eltwise_add_out);

    std::unordered_set<Node *> rm_nodes;
    rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance});
    end_pattern_remove_nodes.push_back(rm_nodes);
    end_pattern_biases.push_back(layer_norm_bias);
    end_pattern_scales.push_back(layer_norm_scale);
    end_pattern_out.push_back(layer_norm_out);
    end_patter_layernorms.push_back(layer_norm);
  };
  gpd3(graph, handler3);
  if (start_pattern_in_nodes.empty() || end_pattern_elt_out.empty()) {
    return 0;
  }

  // only reserve the subgraphs that in connected domains.
  int fusion_count = 0;
  // fusion_id for (i, k, js)
  std::vector<std::pair<size_t, std::pair<size_t, std::vector<size_t>>>>
      fusion_ids;
  for (size_t i = 0; i < start_pattern_in_nodes.size(); ++i) {
    Node *tmp = start_pattern_out_node[i];
    Node *old_tmp = nullptr;
    // get correct inner pattern node order.
    std::vector<size_t> js;
    while (tmp != old_tmp) {
      old_tmp = tmp;
      for (size_t j = 0; j < inner_pattern_tmp_in.size(); ++j) {
        if (inner_pattern_tmp_in[j] == tmp) {
          tmp = inner_pattern_out[j];
          js.push_back(j);
          break;
        }
      }
    }
    for (size_t k = 0; k < end_pattern_elt_out.size(); ++k) {
      if (tmp == end_pattern_elt_out[k]) {
        fusion_ids.push_back(std::make_pair(i, std::make_pair(k, js)));
        break;
      }
    }
  }
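The loop above stitches the three detected pattern groups into whole subgraphs: starting from each start-pattern output, it repeatedly hops through any inner pattern whose input matches the current node, recording the visited indices in `js`, until the node stops changing, and then looks the node up among the end-pattern outputs. A standalone sketch of that chaining logic, with plain `int` ids standing in for `Node *` and a hypothetical `FollowChain` name:

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Follow a start-pattern output through zero or more inner patterns
// (inner_tmp_in[j] -> inner_out[j]) until no inner pattern consumes it,
// then return the index k of the matching end-pattern output (-1 if none)
// together with the traversal order js.
std::pair<int, std::vector<std::size_t>> FollowChain(
    int start_out, const std::vector<int> &inner_tmp_in,
    const std::vector<int> &inner_out, const std::vector<int> &end_elt_out) {
  int tmp = start_out, old_tmp = -1;
  std::vector<std::size_t> js;
  while (tmp != old_tmp) {  // fixed point: no inner pattern matched this round
    old_tmp = tmp;
    for (std::size_t j = 0; j < inner_tmp_in.size(); ++j) {
      if (inner_tmp_in[j] == tmp) {
        tmp = inner_out[j];  // hop to that inner pattern's output
        js.push_back(j);
        break;
      }
    }
  }
  for (std::size_t k = 0; k < end_elt_out.size(); ++k) {
    if (tmp == end_elt_out[k]) return {static_cast<int>(k), js};
  }
  return {-1, js};  // disconnected: this start pattern is not fused
}
```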
  for (size_t num = 0; num < fusion_ids.size(); ++num) {
    int i = fusion_ids[num].first;
    int k = fusion_ids[num].second.first;
    std::vector<size_t> js = fusion_ids[num].second.second;

    std::vector<std::string> ids;
    std::vector<std::string> embs;
    for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
      ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
      embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
    }
    for (size_t iter = 0; iter < js.size(); ++iter) {
      ids.push_back(inner_pattern_ins[js[iter]].first->Name());
      embs.push_back(inner_pattern_ins[js[iter]].second->Name());
    }

    OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
    new_op_desc.SetType("fused_embedding_eltwise_layernorm");
    new_op_desc.SetInput("Ids", ids);
    new_op_desc.SetInput("Embs", embs);
    new_op_desc.SetInput("WordId", {ids[0]});
    if (use_varseqlen && pos_id != "" && mask_id != "") {
      new_op_desc.SetInput("PosId", {pos_id});
      new_op_desc.SetInput("MaskId", {mask_id});
    } else {
      new_op_desc.SetInput("PosId", {ids[1]});
    }
    if (ids.size() > 2) {
      new_op_desc.SetInput("SentId", {ids[2]});
    }
    new_op_desc.SetInput("WordEmbedding", {embs[0]});
    new_op_desc.SetInput("PosEmbedding", {embs[1]});
    if (embs.size() > 2) {
      new_op_desc.SetInput("SentEmbedding", {embs[2]});
    }
    new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
    new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
    new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
    new_op_desc.SetAttr("epsilon",
                        end_patter_layernorms[k]->Op()->GetAttr("epsilon"));

    if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
      new_op_desc.SetAttr("enable_int8", true);
      new_op_desc.SetAttr(
          "out_threshold",
          end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
    }

    auto *embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
    for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
      IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
                      embedding_eltwise_layernorm);
      IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
                      embedding_eltwise_layernorm);
    }
    for (size_t iter = 0; iter < js.size(); ++iter) {
      IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
                      embedding_eltwise_layernorm);
      IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
                      embedding_eltwise_layernorm);
    }
    IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
    IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
    IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);

    // Remove unneeded nodes.
    std::unordered_set<const Node *> marked_nodes;
    marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
                        start_pattern_remove_nodes[i].end());
    marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
                        end_pattern_remove_nodes[k].end());
    for (size_t iter = 0; iter < js.size(); ++iter) {
      marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
                          inner_pattern_remove_nodes[js[iter]].end());
    }
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  }

  return fusion_count;
}
TrtEmbeddingEltwiseLayerNormFusePass::TrtEmbeddingEltwiseLayerNormFusePass() {
  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .End();

  AddOpCompat(OpCompat("layer_norm"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Scale")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddOutput("Y")
      .IsTensor()
      .End()
      .AddOutput("Mean")
      .IsTensor()
      .End()
      .AddOutput("Variance")
      .IsTensor()
      .End()
      .AddAttr("epsilon")
      .IsNumGE(0.0f)
      .IsNumLE(0.001f)
      .End()
      .AddAttr("begin_norm_axis")
      .IsNumGT(0)
      .End();
}
void TrtEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph *graph) const {
  bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
  if (!with_dynamic_shape) {
    VLOG(3) << "trt_embedding_eltwise_layernorm_fuse_pass need: use_varseqlen, "
               "with_dynamic_shape. Stop this pass, "
               "please reconfig.";
    return;
  }
  FusePassBase::Init(name_scope_, graph);
  int fusion_count =
      TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
  if (fusion_count > 0) {
    bool use_varseqlen = Get<bool>("use_varseqlen");
    std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

    if ((use_varseqlen && pos_id != "" && mask_id != "") ||
        (!use_varseqlen && pos_id == "" && mask_id == "")) {
      VLOG(3) << "start trt_embedding_eltwise_layernorm_fuse_pass";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Use transformer'varseqlen need config: "
                                  "use_varseqlen, set pos_id, set "
                                  "mask_id. Or not use varseqlen, do not set "
                                  "pos_id, set mask_id. Please "
                                  "reconfig"));
    }
    graph->Set(kEmbEltwiseLayernormPass, new bool(true));
  }
  AddStatis(fusion_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
REGISTER_PASS(trt_embedding_eltwise_layernorm_fuse_pass,
              paddle::framework::ir::TrtEmbeddingEltwiseLayerNormFusePass);
REGISTER_PASS_CAPABILITY(trt_embedding_eltwise_layernorm_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("lookup_table", 1)
            .LE("lookup_table_v2", 1)
            .LE("elementweise_add", 1));
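Numerically, the subgraph this pass replaces sums one embedding row per input id and layer-normalizes the result with the matched `scale`, `bias`, and `epsilon`. A scalar reference sketch of those semantics (a simplified model for one token position, not the actual fused kernel):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// out[d] = (sum_k embs[k][ids[k]][d] - mean) / sqrt(var + epsilon)
//          * scale[d] + bias[d]
std::vector<float> FusedEmbEltwiseLayerNormRef(
    const std::vector<int> &ids,
    const std::vector<std::vector<std::vector<float>>> &embs,
    const std::vector<float> &scale, const std::vector<float> &bias,
    float epsilon) {
  const std::size_t h = scale.size();
  std::vector<float> sum(h, 0.0f);
  for (std::size_t k = 0; k < ids.size(); ++k)  // one lookup per table,
    for (std::size_t d = 0; d < h; ++d)         // then elementwise_add
      sum[d] += embs[k][ids[k]][d];
  float mean = 0.0f, var = 0.0f;
  for (float v : sum) mean += v;
  mean /= h;
  for (float v : sum) var += (v - mean) * (v - mean);
  var /= h;  // layer_norm statistics over the hidden dimension
  std::vector<float> out(h);
  for (std::size_t d = 0; d < h; ++d)
    out[d] = (sum[d] - mean) / std::sqrt(var + epsilon) * scale[d] + bias[d];
  return out;
}
```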
paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.h
0 → 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {

class Graph;

}  // namespace ir
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// detect start pattern.
//
//    in_var  emb        in_var  emb
//      |      |           |      |
//     lookup_table       lookup_table
//          |                  |
//       lkt_var            lkt_var
//           \                /
//            elementwise_add
//                   |
//              elt_out_var
//
struct TrtEmbedding2Eltwise1Pattern : public PatternBase {
  TrtEmbedding2Eltwise1Pattern(PDPattern *pattern,
                               const std::string &name_scope)
      : PatternBase(pattern, name_scope, "embedding2_eltwise1") {}

  void operator()();

  PATTERN_DECL_NODE(feed1);
  PATTERN_DECL_NODE(feed2);
  PATTERN_DECL_NODE(lookup_table1_x);
  PATTERN_DECL_NODE(lookup_table2_x);
  PATTERN_DECL_NODE(lookup_table1_w);
  PATTERN_DECL_NODE(lookup_table2_w);
  PATTERN_DECL_NODE(lookup_table1);
  PATTERN_DECL_NODE(lookup_table2);
  PATTERN_DECL_NODE(lookup_table1_out);
  PATTERN_DECL_NODE(lookup_table2_out);
  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);
};
// detect repeats inner pattern
//
// elt_out_var    in_var  emb
//       \          |      |
//        \       lookup_table
//         \           |
//          \       lkt_var
//           \        /
//         elementwise_add
//                |
//           elt_out_var
//
struct TrtEmbedding1Eltwise1Pattern : public PatternBase {
  TrtEmbedding1Eltwise1Pattern(PDPattern *pattern,
                               const std::string &name_scope)
      : PatternBase(pattern, name_scope, "embedding1_eltwise1") {}

  void operator()();

  PATTERN_DECL_NODE(feed1);
  PATTERN_DECL_NODE(lookup_table1_x);
  PATTERN_DECL_NODE(lookup_table1_w);
  PATTERN_DECL_NODE(lookup_table1);
  PATTERN_DECL_NODE(lookup_table1_out);
  PATTERN_DECL_NODE(eltwise_add_in);
  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);
};
// detect end pattern
//
//     elementwise_add
//            |
//       elt_out_var
// scale      |      bias
//     \      |      /
//       layer_norm
//
struct TrtSkipLayerNorm : public PatternBase {
  TrtSkipLayerNorm(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "skip_layernorm") {}

  void operator()();

  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);
  PATTERN_DECL_NODE(layer_norm);
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_out);
  // Delete the mean and var nodes in the graph.
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
};
}  // namespace patterns
// The TrtEmbeddingEltwiseLayerNormFusePass detect the following pattern:
//
// inputs operator output
// --------------------------------------------------------------------
// (word, weights_0) lookup_table -> word_emb
// (pos, weights_1) lookup_table -> pos_emb
// (sent, weights_2) lookup_table -> sent_emb
// (word_emb, pos_emb) elementweise_add -> elementwise_out_0
// (elemtwise_out_0, sent_emb) elementweise_add -> elementwise_out_1
// (elementwise_out_1, scale, bias) layer_norm -> layer_norm_out
//
// and then convert the corresponding subgraph to:
//
// (word, pos, sent, weights_0, weights_1, weights_2,
// scale, baias) embedding_eltwise_layernorm -> layer_norm_out
//
//
// in_var emb_var in_var emb_var in_var emb_var in_var emb_var
// | | | | | | | |
// lookup_table lookup_table lookup_table ... lookup_table
// | | | |
// lkt_var lkt_var lkt_var lkt_var
// \ / | ... |
// elementwise_add | |
// \ / |
// elementwise_add |
// | |
// elt_var /
// \ /
// elementwise_add
// |
// layer_norm
class
TrtEmbeddingEltwiseLayerNormFusePass
:
public
FusePassBase
{
public:
TrtEmbeddingEltwiseLayerNormFusePass
();
virtual
~
TrtEmbeddingEltwiseLayerNormFusePass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
/*const Scope* scope*/
)
const
;
const
std
::
string
name_scope_
{
"trt_embedding_eltwise_layernorm_fuse"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
0 → 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

static void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) {
  if (op->IsOp() && op->Op()) {
    new_var->inputs.push_back(op);
    for (size_t i = 0; i < op->outputs.size(); ++i) {
      if (op->outputs[i] == old_var) {
        op->outputs[i] = new_var;
        op->Op()->RenameOutput(old_var->Name(), new_var->Name());
      }
    }
  }
}
static int BuildFusion(Graph* graph, const std::string& name_scope) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Create pattern.
  TrtMultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
  multihead_pattern();
  // Create New OpDesc
  auto fuse_creater = [&](Node* input0, Node* mul0, Node* mul1, Node* mul2,
                          Node* mul0_out, Node* mul1_out, Node* mul2_out,
                          Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b,
                          Node* eltadd_qk_b, Node* reshape2,
                          Node* reshape2_qkv_out, Node* scale,
                          Node* scale_out) {
    auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
    // auto scale_bias = BOOST_GET_CONST(float, scale->Op()->GetAttr("bias"));
    // bool after_scale =
    //     BOOST_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale"));

    // create multihead
    OpDesc multihead_op_desc(mul0->Op()->Block());

    // create tmp tensor
    VarDesc k_var_desc(*mul1_out->Var());
    k_var_desc.SetName("K" + mul1_out->Name());
    auto* k_var_node = graph->CreateVarNode(&k_var_desc);

    VarDesc q_var_desc(*mul0_out->Var());
    q_var_desc.SetName("Q" + mul0_out->Name());
    auto* q_var_node = graph->CreateVarNode(&q_var_desc);

    VarDesc v_var_desc(*mul2_out->Var());
    v_var_desc.SetName("V" + mul2_out->Name());
    auto* v_var_node = graph->CreateVarNode(&v_var_desc);

    auto reshape_desc = reshape2->Op();
    int head_number =
        BOOST_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")).at(2);

    ReplaceOutputVar(mul0, mul0_out, q_var_node);
    ReplaceOutputVar(mul1, mul1_out, k_var_node);
    ReplaceOutputVar(mul2, mul2_out, v_var_node);

    multihead_op_desc.SetType("multihead_matmul");
    multihead_op_desc.SetInput("Q", {q_var_node->Name()});
    multihead_op_desc.SetInput("K", {k_var_node->Name()});
    multihead_op_desc.SetInput("V", {v_var_node->Name()});

    multihead_op_desc.SetInput("BiasQ", {eltadd0_b->Name()});
    multihead_op_desc.SetInput("BiasK", {eltadd1_b->Name()});
    multihead_op_desc.SetInput("BiasV", {eltadd2_b->Name()});
    multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});

    multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
    multihead_op_desc.SetAttr("alpha", scale_attr);
    multihead_op_desc.SetAttr("head_number", head_number);

    auto* multihead = graph->CreateOpNode(&multihead_op_desc);
    IR_NODE_LINK_TO(q_var_node, multihead);
    IR_NODE_LINK_TO(k_var_node, multihead);
    IR_NODE_LINK_TO(v_var_node, multihead);

    IR_NODE_LINK_TO(eltadd0_b, multihead);
    IR_NODE_LINK_TO(eltadd1_b, multihead);
    IR_NODE_LINK_TO(eltadd2_b, multihead);
    IR_NODE_LINK_TO(eltadd_qk_b, multihead);

    IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
  };

  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
                              multihead_pattern);

    // nodes need be removed
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
                              multihead_pattern);

    fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
                 eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
                 reshape2_qkv_out, scale, scale_out);

    std::unordered_set<const Node*> marked_nodes(
        {eltadd0, eltadd1, eltadd2, eltadd0_out, eltadd1_out, eltadd2_out,
         reshape2_0, reshape2_1, reshape2_2, reshape2_0_out, reshape2_1_out,
         reshape2_2_out, transpose2_0, transpose2_1, transpose2_2,
         transpose2_0_out, transpose2_1_out, transpose2_2_out, matmul_qk,
         matmul_qk_out, eltadd_qk, eltadd_qk_out, softmax_qk, softmax_qk_out,
         // dropout_qk, dropout_qk_out,
         transpose2_qkv, transpose2_qkv_out, matmul_qkv, matmul_qkv_out,
         mul0_out, mul1_out, mul2_out, reshape2_qkv, scale});

    // Remove unneeded nodes.
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };
  gpd(graph, handler);

  return fusion_count;
}
PDNode* TrtMultiHeadMatmulPattern::operator()() {
  auto* input0 = pattern->NewNode(input0_repr());
  input0->assert_is_op_input("mul");

  // First path with scale
  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
  auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
                         ->AsInput()
                         ->assert_is_op_input("mul", "Y");
  auto* mul0_out_var =
      pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul");

  decltype(mul0) eltadd0;
  decltype(mul0) eltadd0_b_var;
  decltype(mul0) eltadd0_out_var;

  mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");

  eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
  eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
                      ->AsInput()
                      ->assert_is_op_input("elementwise_add", "Y");

  eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
                        ->assert_is_op_output("elementwise_add");
  eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");

  auto* reshape2_0 =
      pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
  auto* reshape2_0_out_var =
      pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
  reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");

  auto* transpose2_0 =
      pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
  auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
                                   ->assert_is_op_output("transpose2");
  transpose2_0_out_var->AsIntermediate()->assert_is_op_input("scale");

  auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
  auto* scale_out_var =
      pattern->NewNode(scale_out_repr())->assert_is_op_output("scale");
  scale_out_var->AsIntermediate()->assert_is_op_input("matmul");

  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
  auto* matmul_qk_out_var =
      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");

  auto* eltadd_qk =
      pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
  auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
                              ->AsInput()
                              ->assert_is_op_input("elementwise_add", "Y");
  auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
                                ->assert_is_op_output("elementwise_add");
  eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");

  auto* softmax_qk =
      pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
  auto* softmax_qk_out_var =
      pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
  softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul");

  auto* matmul_qkv =
      pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul");
  auto* matmul_qkv_out_var =
      pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul");
  matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");

  auto* transpose2_qkv =
      pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
  auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
                                     ->assert_is_op_output("transpose2");
  transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");

  auto* reshape2_qkv =
      pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
  auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
                                   ->assert_is_op_output("reshape2");
  reshape2_qkv_out_var->assert_is_op_input("mul");

  // Second path to matmul
  auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul");
  auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
                         ->AsInput()
                         ->assert_is_op_input("mul", "Y");
  auto* mul1_out_var =
      pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul");

  decltype(mul1) eltadd1;
  decltype(mul1) eltadd1_b_var;
  decltype(mul1) eltadd1_out_var;

  mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
  eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
  eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
                      ->AsInput()
                      ->assert_is_op_input("elementwise_add", "Y");

  eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
                        ->assert_is_op_output("elementwise_add");
  eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");

  auto* reshape2_1 =
      pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
  auto* reshape2_1_out_var =
      pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
  reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");

  auto* transpose2_1 =
      pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
  auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
                                   ->assert_is_op_output("transpose2");
  transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
      "matmul");  // link to matmul qk

  // Third path to matmul
  auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul");
  auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
                         ->AsInput()
                         ->assert_is_op_input("mul", "Y");
  auto* mul2_out_var =
      pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul");

  decltype(mul2) eltadd2;
  decltype(mul2) eltadd2_b_var;
  decltype(mul2) eltadd2_out_var;

  mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
  eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
  eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
                      ->AsInput()
                      ->assert_is_op_input("elementwise_add", "Y");

  eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
                        ->assert_is_op_output("elementwise_add");
  eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");

  auto* reshape2_2 =
      pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
  auto* reshape2_2_out_var =
      pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
  reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");

  auto* transpose2_2 =
      pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
  auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
                                   ->assert_is_op_output("transpose2");
  transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
      "matmul");  // link to matmul qkv

  // Q path
  mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
  eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
  reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
  transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
  scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var});
  // K path
  mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
  eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
  reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
  transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
  // compute q*k
  matmul_qk->LinksFrom({scale_out_var, transpose2_1_out_var})
      .LinksTo({matmul_qk_out_var});
  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
      .LinksTo({eltadd_qk_out_var});
  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
  // V path
  mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
  eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
  reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
  transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
  // compute q*k*v
  matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
      .LinksTo({matmul_qkv_out_var});
  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
      .LinksTo({transpose2_qkv_out_var});
  reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
      .LinksTo({reshape2_qkv_out_var});

  return transpose2_2_out_var;
}
PDNode* TrtMultiHeadMatmulV3Pattern::operator()() {
  std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
  auto* input0 = pattern->NewNode(input0_repr());
  input0->assert_is_ops_input(matmul_ops);

  // First path with scale
  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(matmul_ops);
  auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
                         ->AsInput()
                         ->assert_is_ops_input(matmul_ops, "Y");
  auto* mul0_out_var =
      pattern->NewNode(mul0_out_repr())->assert_is_ops_output(matmul_ops);

  decltype(mul0) eltadd0;
  decltype(mul0) eltadd0_b_var;
  decltype(mul0) eltadd0_out_var;

  mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");

  eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
  eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
                      ->AsInput()
                      ->assert_is_op_input("elementwise_add", "Y");

  eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
                        ->assert_is_op_output("elementwise_add");
  eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");

  auto* reshape2_0 =
      pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
  auto* reshape2_0_out_var =
      pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
  reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");

  auto* transpose2_0 =
      pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
  auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
                                   ->assert_is_op_output("transpose2");
  transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops, "X");

  auto* matmul_qk =
      pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
  auto* matmul_qk_out_var =
      pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops);
  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");

  auto* eltadd_qk =
      pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
  auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
                              ->AsInput()
                              ->assert_is_op_input("elementwise_add", "Y");
  auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
                                ->assert_is_op_output("elementwise_add");
  eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");

  auto* softmax_qk =
      pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
  auto* softmax_qk_out_var =
      pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
  softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);

  auto* matmul_qkv =
      pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
  auto* matmul_qkv_out_var =
      pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
  matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");

  auto* transpose2_qkv =
      pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
  auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
                                     ->assert_is_op_output("transpose2");
  transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");

  auto* reshape2_qkv =
      pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
  auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
                                   ->assert_is_op_output("reshape2");
  reshape2_qkv_out_var->assert_is_ops_input(matmul_ops);

  // Second path to matmul
  auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(matmul_ops);
  auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
                         ->AsInput()
                         ->assert_is_ops_input(matmul_ops, "Y");
  auto* mul1_out_var =
      pattern->NewNode(mul1_out_repr())->assert_is_ops_output(matmul_ops);

  decltype(mul1) eltadd1;
  decltype(mul1) eltadd1_b_var;
  decltype(mul1) eltadd1_out_var;

  mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
  eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
  eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
                      ->AsInput()
                      ->assert_is_op_input("elementwise_add", "Y");

  eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
                        ->assert_is_op_output("elementwise_add");
  eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");

  auto* reshape2_1 =
      pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
  auto* reshape2_1_out_var =
      pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
  reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");

  auto* transpose2_1 =
      pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
  auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
                                   ->assert_is_op_output("transpose2");
  transpose2_1_out_var->AsIntermediate()->assert_is_ops_input(
      matmul_ops, "Y");  // link to matmul qk

  // Third path to matmul
  auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(matmul_ops);
  auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
                         ->AsInput()
                         ->assert_is_ops_input(matmul_ops, "Y");
  auto* mul2_out_var =
      pattern->NewNode(mul2_out_repr())->assert_is_ops_output(matmul_ops);

  decltype(mul2) eltadd2;
  decltype(mul2) eltadd2_b_var;
  decltype(mul2) eltadd2_out_var;

  mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
  eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
  eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
                      ->AsInput()
                      ->assert_is_op_input("elementwise_add", "Y");

  eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
                        ->assert_is_op_output("elementwise_add");
  eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");

  auto* reshape2_2 =
      pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
  auto* reshape2_2_out_var =
      pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
  reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");

  auto* transpose2_2 =
      pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
  auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
                                   ->assert_is_op_output("transpose2");
  transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
      matmul_ops);  // link to matmul qkv

  // Q path
  mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
  eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
  reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
  transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
  // K path
  mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
  eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
  reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
  transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
  // compute q*k
  matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
      .LinksTo({matmul_qk_out_var});
  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
      .LinksTo({eltadd_qk_out_var});
  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
  // V path
  mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
  eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
  reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
  transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
  // compute q*k*v
  matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
      .LinksTo({matmul_qkv_out_var});
  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
      .LinksTo({transpose2_qkv_out_var});
  reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
      .LinksTo({reshape2_qkv_out_var});

  return transpose2_2_out_var;
}
}  // namespace patterns

void TrtMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);

  int fusion_count = patterns::BuildFusion(graph, name_scope_);
  AddStatis(fusion_count);
}
TrtMultiHeadMatmulV2FusePass::TrtMultiHeadMatmulV2FusePass() {
  AddOpCompat(OpCompat("mul"))
      .AddInput("X")  // the shape should be (B, S, N*H)
      .IsTensor()
      .End()
      .AddInput("Y")  // the shape should be (N*H, N*H)
      .IsTensor()
      .End()
      .AddOutput("Out")  // the shape should be (B, S, N*H)
      .IsTensor()
      .End()
      .AddAttr("x_num_col_dims")
      .IsNumEQ(2)
      .End()
      .AddAttr("y_num_col_dims")
      .IsNumEQ(1)
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      // in bias, shape is (B, S, N*H),
      // in biasqk, shape is (B, H, S, S)
      .IsTensor()
      .End()
      .AddInput("Y")
      // in bias, shape is (N*H)
      // in biasqk, shape is (B, H, S, S)
      .IsTensor()
      .End()
      // in bias, shape is (B, S, N*H)
      // in biasqk, shape is (B, H, S, S)
      .AddOutput("Out")
      .IsTensor()
      .End()
      // in bias, it is equal to 2
      // in biasqk, it is equal to -1 or 0
      .AddAttr("axis")
      .IsIntIn({2, -1, 0})
      .End();

  AddOpCompat(OpCompat("reshape2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Shape")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ShapeTensor")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsOptional()
      .IsTensor()
      .End()
      .AddAttr("shape")  // -->(B, S, H, N)  <--(B, S, N*H)
      .IsType<std::vector<int>>()
      .End();

  // -->: (B, S, H, N) -> (B, H, S, N)
  // <--: (B, H, S, N) -> (B, S, H, N)
  AddOpCompat(OpCompat("transpose2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsOptional()
      .IsTensor()
      .End()
      .AddAttr("axis")  // {0, 2, 1, 3}
      .IsType<std::vector<int>>()
      .End();

  AddOpCompat(OpCompat("scale"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("scale")
      .IsType<float>()  // copy to new op. so unconstrained.
      .End()
      .AddAttr("bias")
      .IsNumEQ(0.f)
      .End()
      .AddAttr("bias_after_scale")  // bias is 0, so unconstrained.
      .IsType<bool>()
      .End();

  // QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
  // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
  AddOpCompat(OpCompat("matmul"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("alpha")
      .IsNumEQ(1.0f)
      .End()
      .AddAttr("transpose_X")
      .IsBoolEQ(false)
      .End()
      .AddAttr("transpose_Y")  // QK(true) QKV(false)
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("softmax"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsIntIn({-1, 3})  // shape is (B, H, S, S), so axis is -1 or 3
      .End();
}
int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
                                                const std::string& name_scope,
                                                Scope* scope) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Create pattern.
  patterns::TrtMultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
  multihead_pattern();
  // Create New OpDesc
  auto fuse_creater = [&](
      Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
      Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w,
      Node* mul2_w, Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b,
      Node* eltadd_qk_b, Node* reshape2, Node* reshape2_qkv_out, Node* scale,
      Node* scale_out, Node* softmax_qk, Node* eltadd0, Node* eltadd1,
      Node* eltadd2, Node* matmul_qk, Node* reshape2_qkv) {
    auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));

    // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
    // bias (B * S * 3 * N * H) + bias (3 * N * H)
    // Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H)
    auto* wq_tensor = scope->FindVar(mul0_w->Name())->GetMutable<LoDTensor>();
    auto* wk_tensor = scope->FindVar(mul1_w->Name())->GetMutable<LoDTensor>();
    auto* wv_tensor = scope->FindVar(mul2_w->Name())->GetMutable<LoDTensor>();

    auto* bq_tensor =
        scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
    auto* bk_tensor =
        scope->FindVar(eltadd1_b->Name())->GetMutable<LoDTensor>();
    auto* bv_tensor =
        scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();

    auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
    auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
    auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());

    auto combined_w_dims =
        phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
    auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});

    // Reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
    auto* combined_w_desc = mul0_w->Var();
    combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
    combined_w_desc->SetPersistable(true);

    auto* combined_bias_desc = eltadd0_b->Var();
    combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
    combined_bias_desc->SetPersistable(true);

    framework::LoDTensor tmp_combined_w_tensor;
    tmp_combined_w_tensor.Resize(combined_w_dims);
    auto* tmp_combined_w_data =
        tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());

    std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
    int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
    // Combine the three fc weights together.
    for (int i = 0; i < dims_h; i++) {
      for (int j = 0; j < 3; j++) {
        for (int k = 0; k < dims_w; k++) {
          int out_index = i * (3 * dims_w) + j * dims_w + k;
          int in_index = i * dims_w + k;
          tmp_combined_w_data[out_index] = w_vec[j][in_index];
        }
      }
    }
    wq_tensor->Resize(combined_w_dims);
    auto* new_combined_w_data =
        wq_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_combined_w_data, tmp_combined_w_data,
           sizeof(float) * wq_tensor->numel());

    scope->EraseVars({mul1_w->Name(), mul2_w->Name()});

    framework::LoDTensor tmp_combined_bias_tensor;
    tmp_combined_bias_tensor.Resize(combined_bias_dims);
    auto* tmp_combined_bias_data =
        tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());

    size_t bias_size = bq_tensor->numel();
    memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
    memcpy(tmp_combined_bias_data + bias_size, bk_data,
           sizeof(float) * bias_size);
    memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
           sizeof(float) * bias_size);

    bq_tensor->Resize(combined_bias_dims);
    auto* new_combined_bias_data =
        bq_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_combined_bias_data, tmp_combined_bias_data,
           sizeof(float) * bq_tensor->numel());

    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

    auto reshape_desc = reshape2->Op();
    int head_number =
        BOOST_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
            .at(2);

    OpDesc multihead_op_desc(mul0->Op()->Block());
    multihead_op_desc.SetType("multihead_matmul");

    multihead_op_desc.SetInput("Input", {input0->Name()});
    multihead_op_desc.SetInput("W", {mul0_w->Name()});
    multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
    multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});

    multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
    multihead_op_desc.SetAttr("alpha", scale_attr);
    multihead_op_desc.SetAttr("head_number", head_number);

    auto* mul0_op_desc = mul0->Op();

    // All mul ops share the same input.
    if (multihead_op_desc.HasAttr("Input_scale")) {
      multihead_op_desc.SetAttr("Input_scale",
                                mul0_op_desc->GetAttr("Input_scale"));
    }
    auto* add0_op_desc = eltadd0->Op();
    auto* add1_op_desc = eltadd1->Op();
    auto* add2_op_desc = eltadd2->Op();
    if (add0_op_desc->HasAttr("out_threshold")) {
      auto out_scale0 =
          BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
      auto out_scale1 =
          BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
      auto out_scale2 =
          BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
      auto out_scale_max = std::max(out_scale0, out_scale1);
      out_scale_max = std::max(out_scale_max, out_scale2);
      multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
    }

    auto* softmax_qk_op_desc = softmax_qk->Op();
    auto* matmul_qk_op_desc = matmul_qk->Op();
    if (matmul_qk_op_desc->HasAttr("Input_scale")) {
      multihead_op_desc.SetAttr("qkv2context_plugin_int8", true);
      if (softmax_qk_op_desc->HasAttr("out_threshold")) {
        auto qkv_plugin_scale = BOOST_GET_CONST(
            float, softmax_qk_op_desc->GetAttr("out_threshold"));
        multihead_op_desc.SetAttr("dp_probs", qkv_plugin_scale);
      }
    }
    if (reshape2_qkv->Op()->HasAttr("out_threshold")) {
      multihead_op_desc.SetAttr("out_threshold",
                                reshape2_qkv->Op()->GetAttr("out_threshold"));
    }

    auto* multihead = graph->CreateOpNode(&multihead_op_desc);

    IR_NODE_LINK_TO(input0, multihead);
    IR_NODE_LINK_TO(mul0_w, multihead);
    IR_NODE_LINK_TO(eltadd0_b, multihead);
    IR_NODE_LINK_TO(eltadd_qk_b, multihead);

    IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
  };
  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING)
          << "Op compat check in trt_multihead_matmul_fuse_pass_v2 failed.";
      return;
    }
    // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
                              multihead_pattern);

    // Nodes that need to be removed.
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
                              multihead_pattern);

    // If weights or biases in qkv's fc are shared by multiple multihead_matmul
    // patterns, we do not support this kind of fusion, and this pass will not
    // take effect.
    bool is_fc_params_shared =
        mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
        mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
        eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
    if (is_fc_params_shared) {
      return;
    }
    fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
                 mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
                 eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out,
                 softmax_qk, eltadd0, eltadd1, eltadd2, matmul_qk,
                 reshape2_qkv);

    std::unordered_set<const Node*> marked_nodes(
        {eltadd0, eltadd1, eltadd2, eltadd1_b, eltadd2_b, eltadd0_out,
         eltadd1_out, eltadd2_out, reshape2_0, reshape2_1, reshape2_2,
         reshape2_0_out, reshape2_1_out, reshape2_2_out, transpose2_0,
         transpose2_1, transpose2_2, transpose2_0_out, transpose2_1_out,
         transpose2_2_out, matmul_qk, matmul_qk_out, eltadd_qk, eltadd_qk_out,
         softmax_qk, softmax_qk_out, transpose2_qkv, transpose2_qkv_out,
         matmul_qkv, matmul_qkv_out, mul0, mul1, mul2, mul0_out, mul1_out,
         mul2_out, mul1_w, mul2_w, reshape2_qkv, scale});

    // Remove unneeded nodes.
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };
  gpd(graph, handler);

  return fusion_count;
}
void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal(
          "During the multihead_matmul pass, the scope should not be null."));

  int fusion_count = BuildFusionV2(graph, name_scope_, scope);
  if (fusion_count > 0) {
    bool use_varseqlen = Get<bool>("use_varseqlen");
    std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

    if (use_varseqlen && pos_id != "" && mask_id != "") {
      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) {
        VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v2";
      } else {
        PADDLE_THROW(platform::errors::Fatal(
            "Using the transformer's varseqlen requires the "
            "embedding_eltwise_layernorm_fuse_pass. Please use no_varseqlen "
            "instead."));
      }
    } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
      VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v2";
    } else {
      PADDLE_THROW(platform::errors::Fatal(
          "Using the transformer's varseqlen requires this config: enable "
          "use_varseqlen, set pos_id, and set mask_id. Otherwise, disable "
          "varseqlen and do not set pos_id or mask_id. Please reconfigure."));
    }
    graph->Set(kMultiheadMatmulPass, new bool(true));
  }
  AddStatis(fusion_count);
}
TrtMultiHeadMatmulV3FusePass::TrtMultiHeadMatmulV3FusePass() {
  AddOpCompat(OpCompat("mul"))
      .AddInput("X")  // the shape should be (B, S, N*H)
      .IsTensor()
      .End()
      .AddInput("Y")  // the shape should be (N*H, N*H)
      .IsTensor()
      .End()
      .AddOutput("Out")  // the shape should be (B, S, N*H)
      .IsTensor()
      .End()
      .AddAttr("x_num_col_dims")
      .IsNumEQ(2)
      .End()
      .AddAttr("y_num_col_dims")
      .IsNumEQ(1)
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      // in bias, shape is (B, S, N*H),
      // in biasqk, shape is (B, H, S, S)
      .IsTensor()
      .End()
      .AddInput("Y")
      // in bias, shape is (N*H)
      // in biasqk, shape is (B, H, S, S)
      .IsTensor()
      .End()
      // in bias, shape is (B, S, N*H)
      // in biasqk, shape is (B, H, S, S)
      .AddOutput("Out")
      .IsTensor()
      .End()
      // in bias, it is equal to 2
      // in biasqk, it is equal to -1 or 0
      .AddAttr("axis")
      .IsIntIn({2, -1, 0})
      .End();

  AddOpCompat(OpCompat("reshape2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Shape")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ShapeTensor")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsOptional()
      .IsTensor()
      .End()
      .AddAttr("shape")  // -->(B, S, H, N) <--(B, S, N*H)
      .IsType<std::vector<int>>()
      .End();

  // -->: (B, S, H, N) -> (B, H, S, N)
  // <--: (B, H, S, N) -> (B, S, H, N)
  AddOpCompat(OpCompat("transpose2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsOptional()
      .IsTensor()
      .End()
      .AddAttr("axis")  // {0, 2, 1, 3}
      .IsType<std::vector<int>>()
      .End();

  // QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
  // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
  AddOpCompat(OpCompat("matmul"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("alpha")
      .IsType<float>()  // QK(any value, copied to the new op) QKV(1.0)
      .End()
      .AddAttr("transpose_X")
      .IsBoolEQ(false)
      .End()
      .AddAttr("transpose_Y")  // QK(true) QKV(false)
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("matmul_v2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("trans_x")
      .IsBoolEQ(false)
      .End()
      .AddAttr("trans_y")  // QK(true) QKV(false)
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("softmax"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsIntIn({-1, 3})  // shape is (B, H, S, S), so axis is -1 or 3
      .End();
}
int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
                                                const std::string& name_scope,
                                                Scope* scope) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Create pattern.
  patterns::TrtMultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
  multihead_pattern();
  // Create New OpDesc
  auto fuse_creater = [&](
      Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
      Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w,
      Node* mul2_w, Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b,
      Node* eltadd_qk_b, Node* reshape2, Node* reshape2_qkv_out,
      Node* matmul_qk) {
    auto scale_attr =
        BOOST_GET_CONST(float, matmul_qk->Op()->GetAttr("alpha"));

    // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
    // bias (B * S * 3 * N * H) + bias (3 * N * H)
    // Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H)
    auto* wq_tensor = scope->FindVar(mul0_w->Name())->GetMutable<LoDTensor>();
    auto* wk_tensor = scope->FindVar(mul1_w->Name())->GetMutable<LoDTensor>();
    auto* wv_tensor = scope->FindVar(mul2_w->Name())->GetMutable<LoDTensor>();

    auto* bq_tensor =
        scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
    auto* bk_tensor =
        scope->FindVar(eltadd1_b->Name())->GetMutable<LoDTensor>();
    auto* bv_tensor =
        scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();

    auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
    auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
    auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());

    auto combined_w_dims =
        phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
    auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});

    // Reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
    auto* combined_w_desc = mul0_w->Var();
    combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
    combined_w_desc->SetPersistable(true);

    auto* combined_bias_desc = eltadd0_b->Var();
    combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
    combined_bias_desc->SetPersistable(true);

    framework::LoDTensor tmp_combined_w_tensor;
    tmp_combined_w_tensor.Resize(combined_w_dims);
    auto* tmp_combined_w_data =
        tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());

    std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
    int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
    // Combine the three fc weights together.
    for (int i = 0; i < dims_h; i++) {
      for (int j = 0; j < 3; j++) {
        for (int k = 0; k < dims_w; k++) {
          int out_index = i * (3 * dims_w) + j * dims_w + k;
          int in_index = i * dims_w + k;
          tmp_combined_w_data[out_index] = w_vec[j][in_index];
        }
      }
    }
    wq_tensor->Resize(combined_w_dims);
    auto* new_combined_w_data =
        wq_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_combined_w_data, tmp_combined_w_data,
           sizeof(float) * wq_tensor->numel());

    scope->EraseVars({mul1_w->Name(), mul2_w->Name()});

    framework::LoDTensor tmp_combined_bias_tensor;
    tmp_combined_bias_tensor.Resize(combined_bias_dims);
    auto* tmp_combined_bias_data =
        tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());

    size_t bias_size = bq_tensor->numel();
    memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
    memcpy(tmp_combined_bias_data + bias_size, bk_data,
           sizeof(float) * bias_size);
    memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
           sizeof(float) * bias_size);

    bq_tensor->Resize(combined_bias_dims);
    auto* new_combined_bias_data =
        bq_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_combined_bias_data, tmp_combined_bias_data,
           sizeof(float) * bq_tensor->numel());

    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

    auto reshape_desc = reshape2->Op();
    int head_number =
        BOOST_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
            .at(2);

    OpDesc multihead_op_desc(mul0->Op()->Block());
    multihead_op_desc.SetType("multihead_matmul");

    multihead_op_desc.SetInput("Input", {input0->Name()});
    multihead_op_desc.SetInput("W", {mul0_w->Name()});
    multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
    multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});

    multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
    multihead_op_desc.SetAttr("alpha", scale_attr);
    multihead_op_desc.SetAttr("head_number", head_number);

    auto* multihead = graph->CreateOpNode(&multihead_op_desc);

    IR_NODE_LINK_TO(input0, multihead);
    IR_NODE_LINK_TO(mul0_w, multihead);
    IR_NODE_LINK_TO(eltadd0_b, multihead);
    IR_NODE_LINK_TO(eltadd_qk_b, multihead);

    IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
  };
  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
                              multihead_pattern);

    // Nodes that need to be removed.
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
                              multihead_pattern);

    // If weights or biases in qkv's fc are shared by multiple multihead_matmul
    // patterns, we do not support this kind of fusion, and this pass will not
    // take effect.
    bool is_fc_params_shared =
        mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
        mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
        eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
    if (is_fc_params_shared) {
      return;
    }
    fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
                 mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
                 eltadd_qk_b, reshape2_0, reshape2_qkv_out, matmul_qk);

    std::unordered_set<const Node*> marked_nodes(
        {eltadd0, eltadd1, eltadd2, eltadd1_b, eltadd2_b, eltadd0_out,
         eltadd1_out, eltadd2_out, reshape2_0, reshape2_1, reshape2_2,
         reshape2_0_out, reshape2_1_out, reshape2_2_out, transpose2_0,
         transpose2_1, transpose2_2, transpose2_0_out, transpose2_1_out,
         transpose2_2_out, matmul_qk, matmul_qk_out, eltadd_qk, eltadd_qk_out,
         softmax_qk, softmax_qk_out, transpose2_qkv, transpose2_qkv_out,
         matmul_qkv, matmul_qkv_out, mul0, mul1, mul2, mul0_out, mul1_out,
         mul2_out, mul1_w, mul2_w, reshape2_qkv});

    // Remove unneeded nodes.
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };
  gpd(graph, handler);

  return fusion_count;
}
void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal(
          "During the multihead_matmul pass, the scope should not be null."));

  int fusion_count = BuildFusionV3(graph, name_scope_, scope);
  if (fusion_count > 0) {
    bool use_varseqlen = Get<bool>("use_varseqlen");
    std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

    if (use_varseqlen && pos_id != "" && mask_id != "") {
      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) {
        VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v3";
      } else {
        PADDLE_THROW(platform::errors::Fatal(
            "Using the transformer's varseqlen requires the "
            "embedding_eltwise_layernorm_fuse_pass. Please use no_varseqlen "
            "instead."));
      }
    } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
      VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v3";
    } else {
      PADDLE_THROW(platform::errors::Fatal(
          "Using the transformer's varseqlen requires this config: enable "
          "use_varseqlen, set pos_id, and set mask_id. Otherwise, disable "
          "varseqlen and do not set pos_id or mask_id. Please reconfigure."));
    }
    graph->Set(kMultiheadMatmulPass, new bool(true));
  }
  AddStatis(fusion_count);
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle
REGISTER_PASS(trt_multihead_matmul_fuse_pass,
              paddle::framework::ir::TrtMultiHeadMatmulFusePass);

REGISTER_PASS(trt_multihead_matmul_fuse_pass_v2,
              paddle::framework::ir::TrtMultiHeadMatmulV2FusePass);

REGISTER_PASS(trt_multihead_matmul_fuse_pass_v3,
              paddle::framework::ir::TrtMultiHeadMatmulV3FusePass);

REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v2)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("mul", 0)
            .LE("elementwise_add", 1)
            .EQ("reshape2", 0)
            .EQ("transpose2", 0)
            .EQ("scale", 0)
            .LE("matmul", 1)
            .EQ("softmax", 0));

REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v3)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("elementwise_add", 1)
            .EQ("reshape2", 0)
            .EQ("transpose2", 0)
            .EQ("scale", 0)
            .LE("matmul", 1)
            .EQ("matmul_v2", 0)
            .EQ("softmax", 0));
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct TrtMultiHeadMatmulPattern : public PatternBase {
  TrtMultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "multihead_matmul") {}

  PDNode* operator()();

  // declare operator node's name
  PATTERN_DECL_NODE(input0);
  PATTERN_DECL_NODE(mul0);
  PATTERN_DECL_NODE(mul1);
  PATTERN_DECL_NODE(mul2);
  PATTERN_DECL_NODE(mul0_w);
  PATTERN_DECL_NODE(mul1_w);
  PATTERN_DECL_NODE(mul2_w);
  PATTERN_DECL_NODE(mul0_out);
  PATTERN_DECL_NODE(mul1_out);
  PATTERN_DECL_NODE(mul2_out);
  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_out);
  PATTERN_DECL_NODE(eltadd1_out);
  PATTERN_DECL_NODE(eltadd2_out);
  PATTERN_DECL_NODE(reshape2_0);
  PATTERN_DECL_NODE(reshape2_1);
  PATTERN_DECL_NODE(reshape2_2);
  PATTERN_DECL_NODE(reshape2_qkv);
  PATTERN_DECL_NODE(reshape2_0_out);
  PATTERN_DECL_NODE(reshape2_1_out);
  PATTERN_DECL_NODE(reshape2_2_out);
  PATTERN_DECL_NODE(reshape2_qkv_out);
  PATTERN_DECL_NODE(transpose2_0);
  PATTERN_DECL_NODE(transpose2_1);
  PATTERN_DECL_NODE(transpose2_2);
  PATTERN_DECL_NODE(transpose2_qkv);
  PATTERN_DECL_NODE(transpose2_0_out);
  PATTERN_DECL_NODE(transpose2_1_out);
  PATTERN_DECL_NODE(transpose2_2_out);
  PATTERN_DECL_NODE(transpose2_qkv_out);
  PATTERN_DECL_NODE(scale);
  PATTERN_DECL_NODE(scale_out);
  PATTERN_DECL_NODE(matmul_qk);
  PATTERN_DECL_NODE(matmul_qk_out);
  PATTERN_DECL_NODE(eltadd_qk);
  PATTERN_DECL_NODE(eltadd_qk_b);
  PATTERN_DECL_NODE(eltadd_qk_out);
  PATTERN_DECL_NODE(softmax_qk);
  PATTERN_DECL_NODE(softmax_qk_out);
  PATTERN_DECL_NODE(matmul_qkv);
  PATTERN_DECL_NODE(matmul_qkv_out);
};
struct TrtMultiHeadMatmulV3Pattern : public PatternBase {
  TrtMultiHeadMatmulV3Pattern(PDPattern* pattern,
                              const std::string& name_scope)
      : PatternBase(pattern, name_scope, "multihead_matmul_v3") {}

  PDNode* operator()();

  // declare operator node's name
  PATTERN_DECL_NODE(input0);
  PATTERN_DECL_NODE(mul0);
  PATTERN_DECL_NODE(mul1);
  PATTERN_DECL_NODE(mul2);
  PATTERN_DECL_NODE(mul0_w);
  PATTERN_DECL_NODE(mul1_w);
  PATTERN_DECL_NODE(mul2_w);
  PATTERN_DECL_NODE(mul0_out);
  PATTERN_DECL_NODE(mul1_out);
  PATTERN_DECL_NODE(mul2_out);
  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_out);
  PATTERN_DECL_NODE(eltadd1_out);
  PATTERN_DECL_NODE(eltadd2_out);
  PATTERN_DECL_NODE(reshape2_0);
  PATTERN_DECL_NODE(reshape2_1);
  PATTERN_DECL_NODE(reshape2_2);
  PATTERN_DECL_NODE(reshape2_qkv);
  PATTERN_DECL_NODE(reshape2_0_out);
  PATTERN_DECL_NODE(reshape2_1_out);
  PATTERN_DECL_NODE(reshape2_2_out);
  PATTERN_DECL_NODE(reshape2_qkv_out);
  PATTERN_DECL_NODE(transpose2_0);
  PATTERN_DECL_NODE(transpose2_1);
  PATTERN_DECL_NODE(transpose2_2);
  PATTERN_DECL_NODE(transpose2_qkv);
  PATTERN_DECL_NODE(transpose2_0_out);
  PATTERN_DECL_NODE(transpose2_1_out);
  PATTERN_DECL_NODE(transpose2_2_out);
  PATTERN_DECL_NODE(transpose2_qkv_out);
  PATTERN_DECL_NODE(matmul_qk);
  PATTERN_DECL_NODE(matmul_qk_out);
  PATTERN_DECL_NODE(eltadd_qk);
  PATTERN_DECL_NODE(eltadd_qk_b);
  PATTERN_DECL_NODE(eltadd_qk_out);
  PATTERN_DECL_NODE(softmax_qk);
  PATTERN_DECL_NODE(softmax_qk_out);
  PATTERN_DECL_NODE(matmul_qkv);
  PATTERN_DECL_NODE(matmul_qkv_out);
};
}  // namespace patterns
class TrtMultiHeadMatmulFusePass : public FusePassBase {
 public:
  virtual ~TrtMultiHeadMatmulFusePass() {}

 protected:
  void ApplyImpl(Graph* graph) const;

  const std::string name_scope_{"trt_multihead_matmul_fuse"};
};

class TrtMultiHeadMatmulV2FusePass : public FusePassBase {
 public:
  TrtMultiHeadMatmulV2FusePass();

 protected:
  void ApplyImpl(Graph* graph) const;

  const std::string name_scope_{"trt_multihead_matmul_fuse_v2"};

 private:
  int BuildFusionV2(Graph* graph, const std::string& name_scope,
                    Scope* scope) const;
};

class TrtMultiHeadMatmulV3FusePass : public FusePassBase {
 public:
  TrtMultiHeadMatmulV3FusePass();

 protected:
  void ApplyImpl(Graph* graph) const;

  const std::string name_scope_{"trt_multihead_matmul_fuse_v3"};

 private:
  int BuildFusionV3(Graph* graph, const std::string& name_scope,
                    Scope* scope) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
0 → 100644
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
}  // namespace ir
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct TrtSkipLayerNorm : public PatternBase {
  TrtSkipLayerNorm(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "skip_layernorm") {}

  PDNode *operator()(PDNode *x, PDNode *y);

  // declare operator node's name
  PATTERN_DECL_NODE(elementwise);
  PATTERN_DECL_NODE(layer_norm);
  // declare variable node's name
  PATTERN_DECL_NODE(elementwise_out);  // (elementwise_input_x, elementwise_input_y) -> elementwise_out
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_out);
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
};

PDNode *TrtSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
  // Create nodes for elementwise add op.
  x->assert_is_op_input("elementwise_add", "X");
  y->assert_is_op_input("elementwise_add", "Y");
  auto *elementwise =
      pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
  auto *elementwise_out_var =
      pattern->NewNode(elementwise_out_repr())
          ->AsOutput()
          ->assert_is_only_output_of_op("elementwise_add");

  // Add links for elementwise_add op.
  elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});

  // Create nodes for layer_norm op.
  elementwise_out_var->AsIntermediate()->assert_is_op_input("layer_norm");
  auto *layer_norm =
      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
  auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
                                  ->AsInput()
                                  ->assert_is_persistable_var()
                                  ->assert_is_op_input("layer_norm", "Bias");
  auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
                                   ->AsInput()
                                   ->assert_is_persistable_var()
                                   ->assert_is_op_input("layer_norm", "Scale");
  auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
                                 ->AsOutput()
                                 ->assert_is_op_output("layer_norm", "Y");
  auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
                                  ->AsOutput()
                                  ->assert_is_op_output("layer_norm", "Mean");
  auto *layer_norm_variance_var =
      pattern->NewNode(layer_norm_variance_repr())
          ->AsOutput()
          ->assert_is_op_output("layer_norm", "Variance");

  // Add links for layer_norm op.
  layer_norm
      ->LinksFrom(
          {elementwise_out_var, layer_norm_bias_var, layer_norm_scale_var})
      .LinksTo(
          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
  return layer_norm_out_var;
}

}  // namespace patterns

void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("skip_layernorm_fuse", graph);
  int found_subgraph_count = 0;

  GraphPatternDetector gpd;
  auto *x = gpd.mutable_pattern()
                ->NewNode("skip_layernorm_fuse/x")
                ->AsInput()
                ->assert_is_op_input("elementwise_add", "X")
                ->assert_var_not_persistable();
  auto *y = gpd.mutable_pattern()
                ->NewNode("skip_layernorm_fuse/y")
                ->AsInput()
                ->assert_is_op_input("elementwise_add", "Y")
                ->assert_var_not_persistable();
  patterns::TrtSkipLayerNorm fused_pattern(gpd.mutable_pattern(),
                                           "skip_layernorm_fuse");
  fused_pattern(x, y);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }

    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "skip_layernorm pass in op compat failed.";
      return;
    }

    VLOG(4) << "handle TrtSkipLayerNorm fuse";
    GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
                              fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
                              fused_pattern);

    std::unordered_set<const Node *> del_node_set;

    // Create an TrtSkipLayerNorm op node
    OpDesc new_desc(elementwise->Op()->Block());
    new_desc.SetType("skip_layernorm");

    // inputs
    new_desc.SetInput("X", {subgraph.at(x)->Name()});
    new_desc.SetInput("Y", {subgraph.at(y)->Name()});
    new_desc.SetInput("Scale", {layer_norm_scale->Name()});
    new_desc.SetInput("Bias", {layer_norm_bias->Name()});

    if (layer_norm->Op()->HasAttr("out_threshold")) {
      new_desc.SetAttr("enable_int8", true);
      new_desc.SetAttr("out_threshold",
                       layer_norm->Op()->GetAttr("out_threshold"));
    }

    // outputs
    new_desc.SetOutput("Out", {layer_norm_out->Name()});

    // attrs
    new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
    new_desc.SetAttr("begin_norm_axis",
                     layer_norm->Op()->GetAttr("begin_norm_axis"));

    auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.

    del_node_set.insert(elementwise);
    del_node_set.insert(layer_norm);
    del_node_set.insert(elementwise_out);
    del_node_set.insert(layer_norm_mean);
    del_node_set.insert(layer_norm_variance);
    GraphSafeRemoveNodes(graph, del_node_set);

    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
    IR_NODE_LINK_TO(subgraph.at(y), fused_node);
    IR_NODE_LINK_TO(layer_norm_scale, fused_node);
    IR_NODE_LINK_TO(layer_norm_bias, fused_node);
    IR_NODE_LINK_TO(fused_node, layer_norm_out);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  if (found_subgraph_count > 0) {
    bool use_varseqlen = Get<bool>("use_varseqlen");
    std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
    if (use_varseqlen && pos_id != "" && mask_id != "") {
      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
          graph->Has(framework::ir::kMultiheadMatmulPass)) {
        VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass";
      } else {
        PADDLE_THROW(platform::errors::Fatal(
            "Use transformer'varseqlen need "
            "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
      }
    } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
      VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Use transformer'varseqlen need config: "
                                  "use_varseqlen, set pos_id, set "
                                  "mask_id. Or not use varseqlen, do not set "
                                  "pos_id, set mask_id. Please "
                                  "reconfig"));
    }
  }
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(trt_skip_layernorm_fuse_pass,
              paddle::framework::ir::TrtSkipLayerNormFusePass);
REGISTER_PASS_CAPABILITY(trt_skip_layernorm_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("elementwise_add", 1)
            .EQ("layer_norm", 0));
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h
0 → 100644
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

//          |            |                      |           |
//      other_op1    other_op2             other_op1    other_op2
//          |            |        fuse         \           /
//          |------elementwise_add     ->     skip_layernorm
//                       |                          |
//                  layer_norm                  other_op3
//                       |                          |
//                  other_op3
//                       |
class Graph;

class TrtSkipLayerNormFusePass : public FusePassBase {
 public:
  TrtSkipLayerNormFusePass() {
    AddOpCompat(OpCompat("elementwise_add"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Y")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddAttr("axis")
        .IsIntIn({0, -1})
        .End();

    AddOpCompat(OpCompat("layer_norm"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Scale")
        .IsTensor()
        .End()
        .AddInput("Bias")
        .IsTensor()
        .End()
        .AddOutput("Y")
        .IsTensor()
        .End()
        .AddOutput("Mean")
        .IsTensor()
        .End()
        .AddOutput("Variance")
        .IsTensor()
        .End()
        .AddAttr("epsilon")
        .IsNumGE(0.0f)
        .IsNumLE(0.001f)
        .End()
        .AddAttr("begin_norm_axis")
        .IsNumGT(0)
        .End();
  }

  virtual ~TrtSkipLayerNormFusePass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/inference/analysis/argument.h
@@ -216,8 +216,12 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
                       bool);
   DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
-  DECL_ARGUMENT_FIELD(tensorrt_use_oss, TensorRtUseOSS, bool);
+  DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
   DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved,
                       bool);
+  DECL_ARGUMENT_FIELD(tensorrt_transformer_posid, TensorRtTransformerPosid,
+                      std::string);
+  DECL_ARGUMENT_FIELD(tensorrt_transformer_maskid, TensorRtTransformerMaskid,
+                      std::string);
   DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path,
                       TensorRtShapeRangeInfoPath, std::string);
   DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape, TensorRtTunedDynamicShape,
paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -55,9 +55,13 @@ void IRPassManager::CreatePasses(Argument *argument,
   int pass_num = 0;
   for (const std::string &pass_name : passes) {
     auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
-    pass->Set("use_oss", new bool(argument->tensorrt_use_oss()));
+    pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
     pass->Set("with_interleaved",
               new bool(argument->tensorrt_with_interleaved()));
+    pass->Set("tensorrt_transformer_posid",
+              new std::string(argument->tensorrt_transformer_posid()));
+    pass->Set("tensorrt_transformer_maskid",
+              new std::string(argument->tensorrt_transformer_maskid()));
     pass->Set("disable_logs", new bool(argument->disable_logs()));
     auto precision_mode = argument->tensorrt_precision_mode();
     bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -377,12 +377,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
           Get<int>("workspace_size"), precision_mode, calibrator.get(),
           Get<int>("gpu_device_id"), min_input_shape, max_input_shape,
           opt_input_shape, disable_trt_plugin_fp16);
-  trt_engine->SetUseOSS(Get<bool>("use_oss"));
+  trt_engine->SetUseOSS(Get<bool>("use_varseqlen"));
   trt_engine->SetWithInterleaved(Get<bool>("with_interleaved"));
+  trt_engine->SetTransformerPosid(
+      Get<std::string>("tensorrt_transformer_posid"));
+  trt_engine->SetTransformerMaskid(
+      Get<std::string>("tensorrt_transformer_maskid"));
   trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
   trt_engine->SetDLACore(Get<int>("trt_dla_core"));
   trt_engine->SetUseInspector(Get<bool>("use_inspector"));
-  trt_engine->SetWithErnie(graph->Has(framework::ir::kMultiheadMatmulPass));
+  trt_engine->SetWithErnie(
+      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+      graph->Has(framework::ir::kMultiheadMatmulPass));

   if (use_static_engine) {
     trt_engine_serialized_data = GetTrtEngineSerializedData(
paddle/fluid/inference/api/analysis_config.cc
@@ -256,8 +256,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(trt_dla_core_);
   CP_MEMBER(trt_use_static_engine_);
   CP_MEMBER(trt_use_calib_mode_);
-  CP_MEMBER(trt_use_oss_);
+  CP_MEMBER(trt_use_varseqlen_);
   CP_MEMBER(trt_with_interleaved_);
+  CP_MEMBER(tensorrt_transformer_posid_);
+  CP_MEMBER(tensorrt_transformer_maskid_);
   CP_MEMBER(trt_tuned_dynamic_shape_);
   CP_MEMBER(trt_allow_build_at_runtime_);
   CP_MEMBER(collect_shape_range_info_);
@@ -546,7 +548,7 @@ void AnalysisConfig::Exp_DisableTensorRtOPs(
   trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
 }

-void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; }
+void AnalysisConfig::EnableVarseqlen() { trt_use_varseqlen_ = true; }

 // TODO(Superjomn) refactor this, buggy.
 void AnalysisConfig::Update() {
@@ -1034,9 +1036,13 @@ std::string AnalysisConfig::Summary() {
                          ? shape_range_info_path_
                          : "false"});
-    os.InsertRow({"tensorrt_use_oss", trt_use_oss_ ? "true" : "false"});
+    os.InsertRow(
+        {"tensorrt_use_varseqlen", trt_use_varseqlen_ ? "true" : "false"});
     os.InsertRow({"tensorrt_with_interleaved",
                   trt_with_interleaved_ ? "true" : "false"});
+    os.InsertRow({"tensorrt_transformer_posid", tensorrt_transformer_posid_});
+    os.InsertRow(
+        {"tensorrt_transformer_maskid", tensorrt_transformer_maskid_});
     os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"});
     if (trt_use_dla_) {
       os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)});
paddle/fluid/inference/api/analysis_predictor.cc
@@ -853,8 +853,10 @@ void AnalysisPredictor::PrepareArgument() {
   }
   argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
-  argument_.SetTensorRtUseOSS(config_.trt_use_oss_);
+  argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_);
   argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
+  argument_.SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_);
+  argument_.SetTensorRtTransformerMaskid(config_.tensorrt_transformer_maskid_);
   argument_.SetMinInputShape(config_.min_input_shape_);
   argument_.SetMaxInputShape(config_.max_input_shape_);
   argument_.SetOptimInputShape(config_.optim_input_shape_);
@@ -1803,6 +1805,9 @@ USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
 USE_TRT_CONVERTER(preln_skip_layernorm)
 USE_TRT_CONVERTER(roll)
 USE_TRT_CONVERTER(strided_slice)
+USE_TRT_CONVERTER(transformer_input_convert)
+USE_TRT_CONVERTER(recover_padding)
+USE_TRT_CONVERTER(remove_padding)
 #endif

 namespace paddle_infer {
@@ -1971,6 +1976,20 @@ void InternalUtils::UpdateConfigInterleaved(paddle_infer::Config *c,
 #endif
 }

+void InternalUtils::SetTransformerPosid(
+    paddle_infer::Config *c, const std::string &tensorrt_transformer_posid) {
+#ifdef PADDLE_WITH_CUDA
+  c->tensorrt_transformer_posid_ = tensorrt_transformer_posid;
+#endif
+}
+
+void InternalUtils::SetTransformerMaskid(
+    paddle_infer::Config *c, const std::string &tensorrt_transformer_maskid) {
+#ifdef PADDLE_WITH_CUDA
+  c->tensorrt_transformer_maskid_ = tensorrt_transformer_maskid;
+#endif
+}
+
 void InternalUtils::SyncStream(paddle_infer::Predictor *p) {
 #ifdef PADDLE_WITH_CUDA
   auto *pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get());
paddle/fluid/inference/api/paddle_analysis_config.h
@@ -618,14 +618,14 @@ struct PD_INFER_DECL AnalysisConfig {
   /// may be more high-performance. Libnvinfer_plugin.so greater than
   /// V7.2.1 is needed.
   ///
-  void EnableTensorRtOSS();
+  void EnableVarseqlen();

   ///
   /// \brief A boolean state telling whether to use the TensorRT OSS.
   ///
   /// \return bool Whether to use the TensorRT OSS.
   ///
-  bool tensorrt_oss_enabled() { return trt_use_oss_; }
+  bool tensorrt_varseqlen_enabled() { return trt_use_varseqlen_; }

   ///
   /// \brief Enable TensorRT DLA
@@ -954,8 +954,10 @@ struct PD_INFER_DECL AnalysisConfig {
   Precision tensorrt_precision_mode_{Precision::kFloat32};
   bool trt_use_static_engine_{false};
   bool trt_use_calib_mode_{true};
-  bool trt_use_oss_{false};
+  bool trt_use_varseqlen_{false};
   bool trt_with_interleaved_{false};
+  std::string tensorrt_transformer_posid_{""};
+  std::string tensorrt_transformer_maskid_{""};
   bool trt_use_dla_{false};
   int trt_dla_core_{0};
   std::map<std::string, std::vector<int>> min_input_shape_{};
paddle/fluid/inference/api/paddle_api.h
@@ -435,6 +435,12 @@ class PD_INFER_DECL InternalUtils {
   static void UpdateConfigInterleaved(paddle_infer::Config *c,
                                       bool with_interleaved);
+  static void SetTransformerPosid(
+      paddle_infer::Config *c, const std::string &tensorrt_transformer_posid);
+  static void SetTransformerMaskid(
+      paddle_infer::Config *c, const std::string &tensorrt_transformer_maskid);
   static void SyncStream(paddle_infer::Predictor *pred);
   static void SyncStream(cudaStream_t stream);
   template <typename T>
paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -94,25 +94,25 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "add_support_int8_pass",                        //
       // "fc_fuse_pass",                              //
       "simplify_with_basic_ops_pass",                 //
-      "embedding_eltwise_layernorm_fuse_pass",        //
+      "trt_embedding_eltwise_layernorm_fuse_pass",    //
       "preln_embedding_eltwise_layernorm_fuse_pass",  //
-      "multihead_matmul_fuse_pass_v2",                //
-      "multihead_matmul_fuse_pass_v3",                //
-      "skip_layernorm_fuse_pass",                     //
+      "trt_multihead_matmul_fuse_pass_v2",            //
+      "trt_multihead_matmul_fuse_pass_v3",            //
+      "trt_skip_layernorm_fuse_pass",                 //
       "preln_skip_layernorm_fuse_pass",               //
       // "set_transformer_input_convert_pass",        //
       "conv_bn_fuse_pass",                            //
       "unsqueeze2_eltwise_fuse_pass",                 //
       "trt_squeeze2_matmul_fuse_pass",                //
       "trt_reshape2_matmul_fuse_pass",                //
       "trt_flatten2_matmul_fuse_pass",                //
       "trt_map_matmul_v2_to_mul_pass",                //
       "trt_map_matmul_v2_to_matmul_pass",             //
       "trt_map_matmul_to_mul_pass",                   //
       "fc_fuse_pass",                                 //
       "conv_elementwise_add_fuse_pass",               //
-      // "remove_padding_recover_padding_pass",       //
-      // "delete_remove_padding_recover_padding_pass",//
+      "remove_padding_recover_padding_pass",          //
+      "delete_remove_padding_recover_padding_pass",   //
       // "yolo_box_fuse_pass",                        //
       "tensorrt_subgraph_pass",                       //
       "conv_bn_fuse_pass",                            //
paddle/fluid/inference/capi_exp/pd_config.cc
@@ -303,13 +303,13 @@ void PD_ConfigDisableTensorRtOPs(__pd_keep PD_Config* pd_config, size_t ops_num,
   config->Exp_DisableTensorRtOPs(ops_list);
 }
-void PD_ConfigEnableTensorRtOSS(__pd_keep PD_Config* pd_config) {
+void PD_ConfigEnableVarseqlen(__pd_keep PD_Config* pd_config) {
   CHECK_AND_CONVERT_PD_CONFIG;
-  config->EnableTensorRtOSS();
+  config->EnableVarseqlen();
 }
 PD_Bool PD_ConfigTensorRtOssEnabled(__pd_keep PD_Config* pd_config) {
   CHECK_AND_CONVERT_PD_CONFIG;
-  return config->tensorrt_oss_enabled();
+  return config->tensorrt_varseqlen_enabled();
 }
 void PD_ConfigEnableTensorRtDla(__pd_keep PD_Config* pd_config,
paddle/fluid/inference/capi_exp/pd_config.h
@@ -432,7 +432,7 @@ PADDLE_CAPI_EXPORT extern void PD_ConfigDisableTensorRtOPs(
 ///
 /// \param[in] pd_onfig config
 ///
-PADDLE_CAPI_EXPORT extern void PD_ConfigEnableTensorRtOSS(
+PADDLE_CAPI_EXPORT extern void PD_ConfigEnableVarseqlen(
     __pd_keep PD_Config* pd_config);
 ///
 /// \brief A boolean state telling whether to use the TensorRT OSS.
paddle/fluid/inference/goapi/config.go
@@ -500,8 +500,8 @@ func (config *Config) DisableTensorRtOPs(ops []string) {
 /// may be more high-performance. Libnvinfer_plugin.so greater than
 /// V7.2.1 is needed.
 ///
-func (config *Config) EnableTensorRtOSS() {
-	C.PD_ConfigEnableTensorRtOSS(config.c)
+func (config *Config) EnableVarseqlen() {
+	C.PD_ConfigEnableVarseqlen(config.c)
 }

 ///
paddle/fluid/inference/goapi/config_test.go
@@ -54,7 +54,7 @@ func TestNewConfig(t *testing.T) {
 	}
 	config.SetTRTDynamicShapeInfo(minInputShape, maxInputShape, optInputShape, false)
-	config.EnableTensorRtOSS()
+	config.EnableVarseqlen()
 	t.Logf("TensorrtOssEnabled:%+v", config.TensorrtOssEnabled())
 	config.EnableTensorRtDLA(0)
@@ -138,4 +138,4 @@ func TestONNXRuntime(t *testing.T) {
 	config.SetCpuMathLibraryNumThreads(4)
 	t.Logf("CpuMathLibraryNumThreads:%+v", config.CpuMathLibraryNumThreads())
 }
\ No newline at end of file
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -56,7 +56,11 @@ nv_library(tensorrt_converter
             strided_slice_op.cc
             preln_skip_layernorm.cc
             roll_op.cc
+            transformer_input_convert_op.cc
+            remove_padding_op.cc
+            recover_padding_op.cc
             DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)

 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
   paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter)
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
@@ -30,23 +30,28 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
 #if IS_TRT_VERSION_GE(6000)
     VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer";
     framework::OpDesc op_desc(op, nullptr);
     auto word_id_name = op_desc.Input("WordId").front();
-    auto pos_id_name = op_desc.Input("PosId").front();
+    auto pos_id_name = engine_->tensorrt_transformer_posid();
     engine_->Set("ernie_pos_name", new std::string(pos_id_name));
     auto sent_id_name = op_desc.Input("SentId").front();
+    auto mask_id_name = engine_->tensorrt_transformer_maskid();
     auto word_emb_name = op_desc.Input("WordEmbedding").front();
     auto pos_emb_name = op_desc.Input("PosEmbedding").front();
     auto sent_emb_name = op_desc.Input("SentEmbedding").front();

     std::vector<std::string> id_names;
     std::vector<std::string> emb_names;
+    bool flag_varseqlen =
+        engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != "";

-    if (engine_->use_oss()) {
+    if (flag_varseqlen) {
+      engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
+      engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
+      engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
       id_names =
           std::vector<std::string>{word_id_name, pos_id_name, sent_id_name};
       emb_names =
@@ -106,7 +111,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
     nvinfer1::ILayer* layer = nullptr;
     bool enable_int8 = op_desc.HasAttr("enable_int8");

-    if (engine_->use_oss()) {
+    if (flag_varseqlen) {
       int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
       if (enable_int8) {
         output_fp16 = 1;
@@ -121,7 +126,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
           output_fp16, 1,
           platform::errors::InvalidArgument(
               "Only Precision::KHalf(fp16) is supported when infering "
-              "ernie(bert) model with config.EnableTensorRtOSS(). "
+              "ernie(bert) model with config.EnableVarseqlen(). "
               "But Precision::KFloat32 is setted."));
       const std::vector<nvinfer1::PluginField> fields{
           {"bert_embeddings_layernorm_beta", bias,
...
@@ -159,8 +164,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -159,8 +164,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
plugin_inputs
.
emplace_back
(
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
pos_id_name
));
// cu_seqlens,
engine_
->
GetITensor
(
pos_id_name
));
// cu_seqlens,
// eval_placeholder_2
// eval_placeholder_2
auto
max_seqlen_tensor
=
auto
max_seqlen_tensor
=
engine_
->
GetITensor
(
mask_id_name
);
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
3
)
->
getName
());
auto
*
shuffle_layer
=
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
max_seqlen_tensor
);
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
max_seqlen_tensor
);
nvinfer1
::
Dims
shape_dim
;
nvinfer1
::
Dims
shape_dim
;
...
@@ -193,8 +197,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -193,8 +197,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
engine_
->
SetTensorDynamicRange
(
plugin_layer
->
getOutput
(
1
),
out_scale
);
engine_
->
SetTensorDynamicRange
(
plugin_layer
->
getOutput
(
1
),
out_scale
);
}
}
if
(
engine_
->
with_interleaved
())
{
if
(
engine_
->
with_interleaved
())
{
VLOG
(
4
)
VLOG
(
4
)
<<
"fused emb_eltwise_layernorm op: use_varseqlen and "
<<
"fused emb_eltwise_layernorm op: use_oss and
with_interleaved"
;
"
with_interleaved"
;
if
(
!
enable_int8
)
{
if
(
!
enable_int8
)
{
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"use with_interleaved must be int8."
));
platform
::
errors
::
Fatal
(
"use with_interleaved must be int8."
));
...
@@ -229,12 +233,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -229,12 +233,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
RreplenishLayerAndOutput
(
layer
,
"emb_eltwise_layernorm"
,
{
output_name
},
RreplenishLayerAndOutput
(
layer
,
"emb_eltwise_layernorm"
,
{
output_name
},
test_mode
);
test_mode
);
}
}
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"
));
#endif
}
}
};
};
...
...
paddle/fluid/inference/tensorrt/convert/fc_op.cc
...
@@ -250,8 +250,7 @@ class FcOpConverter : public OpConverter {
     }
     // If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
     // not add Shuffle layer in ernie's multihead.
-    if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
-        x_num_col_dims == 1) {
+    if (x_dim.nbDims == 4 && x_dim.d[3] == 1 && x_num_col_dims == 2) {
       if (enable_int8 || support_int8) {
         // add conv1x1 layer
         nvinfer1::DimsHW nv_ksize(1, 1);
...
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...
@@ -76,12 +76,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
     nvinfer1::ILayer* layer = nullptr;
     auto output_name = op_desc.Output("Out")[0];
+    bool flag_varseqlen = engine_->use_varseqlen() &&
+                          engine_->tensorrt_transformer_posid() != "" &&
+                          engine_->tensorrt_transformer_maskid() != "";
     if (engine_->with_dynamic_shape()) {
-      if (engine_->use_oss()) {
+      if (flag_varseqlen) {
         if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
           PADDLE_THROW(platform::errors::Fatal(
-              "use use_oss must be int8 or half, not float32."));
+              "use use_varseqlen must be int8 or half, not float32."));
         }
         nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
                                  static_cast<void*>(weight_data),
...
@@ -90,7 +92,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
                                static_cast<void*>(bias_data),
                                static_cast<int32_t>(bias_t->numel())};
         if (engine_->with_interleaved()) {
-          VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved";
+          VLOG(4) << "fused multihead_matmul op: use_varseqlen and "
+                     "with_interleaved";
           if (!op_desc.HasAttr("Input_scale")) {
             PADDLE_THROW(
                 platform::errors::Fatal("use with_interleaved must be int8."));
...
@@ -233,9 +236,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
                 BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0;
           }
         }
-        auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
         auto creator = GetPluginRegistry()->getPluginCreator(
             "CustomQKVToContextPluginDynamic", "2");
         assert(creator != nullptr);
...
@@ -272,18 +272,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
         std::vector<nvinfer1::ITensor*> plugin_inputs;
         plugin_inputs.emplace_back(fc_layer->getOutput(0));
-        plugin_inputs.emplace_back(mask_tensor);
-        if (engine_->Has("ernie_pos_name")) {
-          plugin_inputs.emplace_back(engine_->GetITensor(
-              engine_->Get<std::string>("ernie_pos_name")));
-        } else {
-          plugin_inputs.emplace_back(engine_->GetITensor(
-              engine_->network()
-                  ->getInput(2)
-                  ->getName()));  // cu_seqlens, eval_placeholder_2
-        }
-        auto max_seqlen_tensor =
-            engine_->GetITensor(engine_->network()->getInput(3)->getName());
+        plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask"));
+        plugin_inputs.emplace_back(engine_->GetITensor("pos_id"));
+        auto max_seqlen_tensor = engine_->GetITensor("mask_id");
         auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
             engine_, Shuffle,
             *const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
...
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc
...
@@ -32,7 +32,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
#if IS_TRT_VERSION_GE(7000)
     VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
-    if (!(engine_->use_oss() && engine_->with_interleaved())) {
+    if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
       PADDLE_THROW(platform::errors::Fatal(
           "PrelnErnie: If you want to use oss, must be with interleaved"));
     }
...
paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc
...
@@ -24,7 +24,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
                   const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(7000)
     VLOG(4) << "convert fused preln_skip_layernorm op to tensorrt layer";
-    if (!(engine_->use_oss() && engine_->with_interleaved())) {
+    if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
       PADDLE_THROW(platform::errors::Fatal(
           "PrelnErnie: If you want to use oss, must be with interleaved"));
     }
...
@@ -60,7 +60,8 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
     nvinfer1::ILayer* layer = nullptr;
-    VLOG(4) << "fused preln_skip_layernorm op: use_oss and with_interleaved";
+    VLOG(4)
+        << "fused preln_skip_layernorm op: use_varseqlen and with_interleaved";
     auto creator = GetPluginRegistry()->getPluginCreator(
         "CustomSkipLayerNormPluginDynamic", "4");
...
paddle/fluid/inference/tensorrt/convert/recover_padding_op.cc  0 → 100644

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h"

namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * Recover padding of transformer'input.
 */
class RecoverPadding : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "Recover padding of transformer'output: VarSeqlen -> Padding.";
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "recover_padding_op: If you want to use transformer, must "
          "be with dynamic shape"));
    }
    framework::OpDesc op_desc(op, nullptr);
    /*
    auto x_var_name = op_desc.Input(InputNames()).front();
    auto* x_var_desc = block->FindVar(x_var_name);
    const auto x_shape = x_var_desc->GetShape();
    */
    auto input_name = op_desc.Input("Input").front();
    std::cout << "input_name: " << input_name << std::endl;

    std::vector<nvinfer1::ITensor*> plugin_inputs;
    plugin_inputs.push_back(engine_->GetITensor(input_name));
    plugin_inputs.push_back(engine_->GetITensor("pos_id"));
    plugin_inputs.push_back(engine_->GetITensor("mask_id"));
    int input_num = 3;
    auto output_name = op_desc.Output("Out").front();

    plugin::RecoverPaddingPlugin* plugin = new plugin::RecoverPaddingPlugin();
    nvinfer1::ILayer* layer =
        engine_->AddDynamicPlugin(plugin_inputs.data(), input_num, plugin);

    RreplenishLayerAndOutput(layer, "recover_padding", {output_name},
                             test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(recover_padding, RecoverPadding);
paddle/fluid/inference/tensorrt/convert/remove_padding_op.cc  0 → 100644

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h"

namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * Remove padding of transformer'input.
 */
class RemovePadding : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "Remove padding of transformer'input: Padding -> VarSeqlen";
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "remove_padding_op: If you want to use transformer, must "
          "be with dynamic shape"));
    }

    framework::OpDesc op_desc(op, nullptr);
    auto input_name = op_desc.Input("Input").front();

    std::vector<nvinfer1::ITensor*> plugin_inputs;
    plugin_inputs.push_back(engine_->GetITensor(input_name));
    plugin_inputs.push_back(engine_->GetITensor("pos_id"));
    plugin_inputs.push_back(engine_->GetITensor("word_id"));
    size_t input_num = plugin_inputs.size();
    auto output_name = op_desc.Output("Out").front();

    plugin::RemovePaddingPlugin* plugin = new plugin::RemovePaddingPlugin();
    nvinfer1::ILayer* layer =
        engine_->AddDynamicPlugin(plugin_inputs.data(), input_num, plugin);

    RreplenishLayerAndOutput(layer, "remove_padding_op", {output_name},
                             test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(remove_padding, RemovePadding);
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
...
@@ -52,10 +52,13 @@ class SkipLayerNormOpConverter : public OpConverter {
     bool enable_int8 = op_desc.HasAttr("enable_int8");
     nvinfer1::ILayer* layer = nullptr;
-    if (engine_->use_oss()) {
+    bool flag_varseqlen = engine_->use_varseqlen() &&
+                          engine_->tensorrt_transformer_posid() != "" &&
+                          engine_->tensorrt_transformer_maskid() != "";
+    if (flag_varseqlen) {
       if (engine_->with_interleaved()) {
-        VLOG(4) << "fused skip_layernorm op: use_oss and with_interleaved";
+        VLOG(4) << "fused skip_layernorm op: use_varseqlen and with_interleaved";
         if (!enable_int8) {
           PADDLE_THROW(
               platform::errors::Fatal("use with_interleaved must be int8."));
...
paddle/fluid/inference/tensorrt/convert/slice_op.cc
...
@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
-#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"

namespace paddle {
namespace inference {
...
@@ -74,47 +73,12 @@ class SliceOpConverter : public OpConverter {
     nvinfer1::ILayer* layer = nullptr;
     if (engine_->with_dynamic_shape()) {
-      if (engine_->use_oss() && engine_->with_ernie() &&
-          input_dims.nbDims == 4) {
-        std::vector<nvinfer1::ITensor*> plugin_inputs;
-        if (engine_->with_interleaved()) {
-          auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
-          nvinfer1::Permutation transpose_embed{2, 1, 0, 3};
-          shuffler_slice->setSecondTranspose(transpose_embed);
-          engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0),
-                                         out_scale);
-          shuffler_slice->setName(
-              ("SpecialSlice_interleaved: transpose: (Output: " + output_name +
-               ")")
-                  .c_str());
-          plugin_inputs.emplace_back(shuffler_slice->getOutput(0));
-        } else {
-          plugin_inputs.emplace_back(input);
-        }
-        std::string pos_name;
-        if (engine_->Has("ernie_pos_name")) {
-          pos_name = engine_->Get<std::string>("ernie_pos_name");
-        } else {
-          // hard code for compatibility
-          pos_name = engine_->network()->getInput(2)->getName();
-        }
-        plugin_inputs.emplace_back(
-            engine_->GetITensor(pos_name));  // cu_seqlens, eval_placeholder_2
-        // bool ban_fp16 = engine_->disable_trt_plugin_fp16();
-        plugin::SpecialSlicePluginDynamic* plugin =
-            new plugin::SpecialSlicePluginDynamic();
-        layer = engine_->AddDynamicPlugin(plugin_inputs.data(),
-                                          plugin_inputs.size(), plugin);
-      } else {
-        bool with_fp16 =
-            engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
-        int decrease_axis =
-            decrease_axises.size() == 0 ? -1 : decrease_axises[0];
-        plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
-            starts, ends, axes, decrease_axis, with_fp16);
-        layer = engine_->AddDynamicPlugin(&input, 1, plugin);
-      }
+      bool with_fp16 =
+          engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
+      int decrease_axis = decrease_axises.size() == 0 ? -1 : decrease_axises[0];
+      plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
+          starts, ends, axes, decrease_axis, with_fp16);
+      layer = engine_->AddDynamicPlugin(&input, 1, plugin);
     } else {
       bool with_fp16 =
           engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
...
paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc  0 → 100644

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h"

namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * Convert Transformer Input(pos_id, max_seqlen).
 */
class TransformerInputConvert : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "Convert Transformer Input(pos_id, max_seqlen), use "
               "transformer_input_convert_plugin";
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "transformer_input_convert_op: If you want to use transformer, "
          "must be with dynamic shape"));
    }

    framework::OpDesc op_desc(op, nullptr);
    auto input_name = op_desc.Input("Input").front();
    auto* input = engine_->GetITensor(input_name);
    int input_num = op_desc.Input("Input").size();

    // tensorrt_subgraph_pass will rename tensor
    // auto pos_id_name = op_desc.Output("PosId").front();
    // auto max_seqlen_name = op_desc.Output("MaxSeqlen").front();
    auto pos_id_name = "pos_id_tensor";
    auto max_seqlen_name = "max_seqlen_tensor";

    plugin::TransformerInputConvertPlugin* plugin =
        new plugin::TransformerInputConvertPlugin();
    nvinfer1::ILayer* layer =
        engine_->AddDynamicPlugin(&input, input_num, plugin);

    RreplenishLayerAndOutput(layer, "transformer_input_convert",
                             {pos_id_name, max_seqlen_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(transformer_input_convert, TransformerInputConvert);
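The converter above wires a plugin that derives "pos_id_tensor" and "max_seqlen_tensor" from the transformer input. A CPU sketch of the assumed outputs — pos_id as an exclusive prefix sum of per-sequence lengths (the var-seqlen "cu_seqlens" layout the other converters consume) plus the batch maximum — using a hypothetical helper name:

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical CPU sketch: given per-sequence lengths, build the "pos_id"
// offsets [0, l0, l0+l1, ...] and the batch's max sequence length, i.e. the
// two tensors this converter is assumed to produce.
std::pair<std::vector<int>, int> BuildPosId(const std::vector<int>& seq_lens) {
  std::vector<int> pos_id(seq_lens.size() + 1, 0);
  int max_seqlen = 0;
  for (size_t b = 0; b < seq_lens.size(); ++b) {
    pos_id[b + 1] = pos_id[b] + seq_lens[b];  // cumulative offset
    max_seqlen = std::max(max_seqlen, seq_lens[b]);
  }
  return {pos_id, max_seqlen};
}
```

Note that pos_id has batch+1 entries, which is why downstream code (e.g. the recover_padding plugin) computes the batch size as `pos_id.dims.d[0] - 1`.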
paddle/fluid/inference/tensorrt/engine.h
...
@@ -410,14 +410,19 @@ class TensorRTEngine {
       suffix_counter += 1;
   }

-  void SetUseOSS(bool use_oss) { use_oss_ = use_oss; }
+  void SetUseOSS(bool use_varseqlen) { use_varseqlen_ = use_varseqlen; }
   void SetUseDLA(bool use_dla) { use_dla_ = use_dla; }
   void SetDLACore(int dla_core) { dla_core_ = dla_core; }
   void SetWithErnie(bool with_ernie) { with_ernie_ = with_ernie; }
   void SetWithInterleaved(bool with_interleaved) {
     with_interleaved_ = with_interleaved;
   }
+  void SetTransformerPosid(std::string tensorrt_transformer_posid) {
+    tensorrt_transformer_posid_ = tensorrt_transformer_posid;
+  }
+  void SetTransformerMaskid(std::string tensorrt_transformer_maskid) {
+    tensorrt_transformer_maskid_ = tensorrt_transformer_maskid;
+  }
   void ClearWeights() {
     for (auto& weight_pair : weight_map) {
       weight_pair.second.reset(nullptr);
...
@@ -488,9 +493,15 @@ class TensorRTEngine {
     return ret;
   }

-  bool use_oss() { return use_oss_; }
+  bool use_varseqlen() { return use_varseqlen_; }
   bool with_ernie() { return with_ernie_; }
   bool with_interleaved() { return with_interleaved_; }
+  std::string tensorrt_transformer_posid() {
+    return tensorrt_transformer_posid_;
+  }
+  std::string tensorrt_transformer_maskid() {
+    return tensorrt_transformer_maskid_;
+  }
   bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; }
   bool with_dynamic_shape() { return with_dynamic_shape_; }
   AnalysisConfig::Precision precision() { return precision_; }
...
@@ -612,11 +623,13 @@ class TensorRTEngine {
   ShapeMapType max_input_shape_;
   ShapeMapType optim_input_shape_;
   bool disable_trt_plugin_fp16_{false};
-  bool use_oss_{false};
+  bool use_varseqlen_{false};
   bool use_dla_{false};
   int dla_core_{0};
   bool with_ernie_{false};
   bool with_interleaved_{false};
+  std::string tensorrt_transformer_posid_;
+  std::string tensorrt_transformer_maskid_;
   nvinfer1::ILogger& logger_;
   // max data size for the buffers.
...
paddle/fluid/inference/tensorrt/op_teller.cc
...
@@ -125,7 +125,10 @@ struct SimpleOpTypeSetTeller : public Teller {
       "strided_slice",
       "fused_preln_embedding_eltwise_layernorm",
       "roll",
-      "preln_skip_layernorm"};
+      "preln_skip_layernorm",
+      "transformer_input_convert",
+      "recover_padding",
+      "remove_padding"};
   std::unordered_set<std::string> teller_set{
       "mul",
       "matmul",
...
@@ -194,7 +197,10 @@ struct SimpleOpTypeSetTeller : public Teller {
       "fused_preln_embedding_eltwise_layernorm",
       "preln_skip_layernorm",
       "roll",
-      "multiclass_nms3"};
+      "multiclass_nms3",
+      "transformer_input_convert",
+      "recover_padding",
+      "remove_padding"};
 };

 bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
...
@@ -4,7 +4,7 @@ nv_library(tensorrt_plugin
   pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
   instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
   qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
-  hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
+  hard_swish_op_plugin.cu stack_op_plugin.cu
   anchor_generator_op_plugin.cu
   yolo_box_op_plugin.cu
   yolo_box_head_op_plugin.cu
...
@@ -14,6 +14,9 @@ nv_library(tensorrt_plugin
   pool3d_op_plugin.cu
   deformable_conv_op_plugin.cu
   matmul_op_int8_plugin.cu
+  transformer_input_convert_plugin.cu
+  remove_padding_plugin.cu
+  recover_padding_plugin.cu
   DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)

nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
...
paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu
0 → 100644
浏览文件 @
2810dfea
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
__global__
void
RecoverPaddingKernel
(
const
float
*
input0
,
const
int32_t
*
input1
,
float
*
output
)
{
int
word_id
=
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
;
int32_t
seqence_length
=
input1
[
blockIdx
.
x
+
1
]
-
input1
[
blockIdx
.
x
];
if
(
blockIdx
.
y
<
seqence_length
)
{
output
[
word_id
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
input0
[(
input1
[
blockIdx
.
x
]
+
blockIdx
.
y
)
*
               gridDim.z * blockDim.x + blockIdx.z * blockDim.x + threadIdx.x];
  } else {
    output[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x +
           threadIdx.x] = 0;
  }
}

nvinfer1::DataType RecoverPaddingPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  return input_types[0];
}

nvinfer1::DimsExprs RecoverPaddingPlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs output_dims{};
  output_dims.nbDims = 3;
  const auto* one = exprBuilder.constant(1);
  output_dims.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUB,
                                           *inputs[1].d[0], *one);
  output_dims.d[1] = inputs[2].d[1];
  output_dims.d[2] = inputs[0].d[1];
  return output_dims;
}

bool RecoverPaddingPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nbInputs, 3,
                    platform::errors::InvalidArgument("Must have 3 inputs, "
                                                      "but got %d input(s). ",
                                                      nbInputs));
  PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(),
                    platform::errors::InvalidArgument("Must have 1 output, "
                                                      "but got %d output(s). ",
                                                      nbOutputs));
  if (pos == 1) {  // PosId, MaxSeqlen
    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  }
  return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
         inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
  // == nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format ==
  // nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kINT8 && inOut[pos].format ==
  // nvinfer1::TensorFormat::kCHW32);
}

void RecoverPaddingPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* outputs,
    int nbOutputs) TRT_NOEXCEPT {}

void RecoverPaddingPlugin::attachToContext(
    cudnnContext* cudnnContext, cublasContext* cublasContext,
    nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}

void RecoverPaddingPlugin::detachFromContext() TRT_NOEXCEPT {}

void RecoverPaddingPlugin::terminate() TRT_NOEXCEPT {}

int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                                  const nvinfer1::PluginTensorDesc* outputDesc,
                                  const void* const* inputs,
                                  void* const* outputs, void* workspace,
                                  cudaStream_t stream) TRT_NOEXCEPT {
  const auto input0_desc = inputDesc[0];
  const auto input1_desc = inputDesc[1];
  const auto input2_desc = inputDesc[2];
  const float* input0 = static_cast<const float*>(inputs[0]);
  const int32_t* input1 =
      static_cast<const int32_t*>(inputs[1]);  // pos_id_tensor
  float* output = static_cast<float*>(outputs[0]);
  const int32_t num_threads = 256;
  const dim3 num_blocks(
      input1_desc.dims.d[0] - 1, input2_desc.dims.d[1],
      input0_desc.dims.d[1] / num_threads);  // batches, max sequence length
                                             // (mask_id.dims.d[1]),
                                             // input.dims.d[1]/256
  RecoverPaddingKernel<<<num_blocks, num_threads, 0, stream>>>(input0, input1,
                                                               output);
  return cudaGetLastError() != cudaSuccess;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
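The scatter that `RecoverPaddingKernel` performs is easier to see on the host: packed rows at offsets `pos_id[b] .. pos_id[b+1]-1` go back to rows `0 .. seq_len-1` of batch `b`, and padded rows stay zero. A Python sketch of the same index mapping (illustrative only; the function name and list-based tensors are ours, not part of this patch):

```python
def recover_padding(packed, pos_id, max_seqlen, hidden):
    """Scatter packed rows back into a padded layout, zero-filling padding.

    packed:  [sum(seq_lens)][hidden] valid tokens, batches concatenated.
    pos_id:  cumulative sequence offsets, len == batch + 1 (the PosId input).
    """
    batch = len(pos_id) - 1
    padded = [[[0.0] * hidden for _ in range(max_seqlen)] for _ in range(batch)]
    for b in range(batch):
        seq_len = pos_id[b + 1] - pos_id[b]
        for s in range(seq_len):
            padded[b][s] = packed[pos_id[b] + s][:]
    return padded
```

The kernel parallelizes this mapping over a `(batch, max_seqlen, hidden / 256)` grid, writing zeros for rows where `blockIdx.y` is at or beyond the sequence length.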
paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h (new file, 0 → 100644)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cassert>
#include <string>

#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class RecoverPaddingPlugin : public DynamicPluginTensorRT {
 public:
  RecoverPaddingPlugin() {}
  RecoverPaddingPlugin(void const* serial_data, size_t serial_length) {}

  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    RecoverPaddingPlugin* ptr = new RecoverPaddingPlugin();
    return ptr;
  }

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "recover_padding_plugin";
  }

  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }

  int initialize() TRT_NOEXCEPT { return 0; }
  void terminate() TRT_NOEXCEPT;
  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* outputs,
                       int nbOutputs) TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }
  void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
                       nvinfer1::IGpuAllocator* gpuAllocator)
      TRT_NOEXCEPT override;
  void detachFromContext() TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const
      TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 protected:
  size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; }
  void serialize(void* buffer) const TRT_NOEXCEPT override {}
};

class RecoverPaddingPluginCreator : public nvinfer1::IPluginCreator {
 public:
  RecoverPaddingPluginCreator() {}

  const char* getPluginName() const TRT_NOEXCEPT override {
    return "recover_padding_plugin";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
    return &field_collection_;
  }

  nvinfer1::IPluginV2* createPlugin(const char* name,
                                    const nvinfer1::PluginFieldCollection*
                                        plugin_field) TRT_NOEXCEPT override {
    return nullptr;
  }

  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, void const* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
    RecoverPaddingPlugin* obj =
        new RecoverPaddingPlugin(serial_data, serial_length);
    obj->setPluginNamespace(name);
    return obj;
  }

  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
};

REGISTER_TRT_PLUGIN_V2(RecoverPaddingPluginCreator);

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu (new file, 0 → 100644)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

__global__ void RemovePaddingKernel(const float* input0, const int32_t* input1,
                                    float* output) {
  int word_id = blockIdx.x * gridDim.y + blockIdx.y;
  int32_t sequence_length = input1[blockIdx.x + 1] - input1[blockIdx.x];
  if (blockIdx.y < sequence_length) {
    output[(input1[blockIdx.x] + blockIdx.y) * gridDim.z * blockDim.x +
           blockIdx.z * blockDim.x + threadIdx.x] =
        input0[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x +
               threadIdx.x];
  }
}

nvinfer1::DataType RemovePaddingPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  return input_types[0];
}

nvinfer1::DimsExprs RemovePaddingPlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs output_dims{};
  output_dims.nbDims = 4;
  output_dims.d[0] = inputs[2].d[0];
  output_dims.d[1] = inputs[0].d[2];
  output_dims.d[2] = exprBuilder.constant(1);
  output_dims.d[3] = exprBuilder.constant(1);
  return output_dims;
}

bool RemovePaddingPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nbInputs, 3,
                    platform::errors::InvalidArgument("Must have 3 inputs, "
                                                      "but got %d input(s). ",
                                                      nbInputs));
  PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(),
                    platform::errors::InvalidArgument("Must have 1 output, "
                                                      "but got %d output(s). ",
                                                      nbOutputs));
  if (pos == 1 || pos == 2) {  // pos_id, work_id
    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  }
  return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
         inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
  // == nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format ==
  // nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kINT8 && inOut[pos].format ==
  // nvinfer1::TensorFormat::kCHW32);
}

void RemovePaddingPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* outputs,
    int nbOutputs) TRT_NOEXCEPT {}

void RemovePaddingPlugin::attachToContext(
    cudnnContext* cudnnContext, cublasContext* cublasContext,
    nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}

void RemovePaddingPlugin::detachFromContext() TRT_NOEXCEPT {}

void RemovePaddingPlugin::terminate() TRT_NOEXCEPT {}

int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                                 const nvinfer1::PluginTensorDesc* outputDesc,
                                 const void* const* inputs,
                                 void* const* outputs, void* workspace,
                                 cudaStream_t stream) TRT_NOEXCEPT {
  const auto input_desc = inputDesc[0];
  const float* input0 = static_cast<const float*>(inputs[0]);
  const int32_t* input1 =
      static_cast<const int32_t*>(inputs[1]);  // pos_id_tensor
  float* output = static_cast<float*>(outputs[0]);
  const auto input0_desc = inputDesc[0];
  const int32_t num_threads = 256;
  const dim3 num_blocks(
      input0_desc.dims.d[0], input0_desc.dims.d[1],
      input0_desc.dims.d[2] /
          num_threads);  // batches, max sequence length, input.dims.d[2]/256
  RemovePaddingKernel<<<num_blocks, num_threads, 0, stream>>>(input0, input1,
                                                              output);
  return cudaGetLastError() != cudaSuccess;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
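`RemovePaddingKernel` is the inverse gather: the valid rows of a padded `[batch, max_seqlen, hidden]` tensor are packed contiguously at the offsets given by `pos_id`. A host-side Python sketch of the same mapping (illustrative only; the function name and list-based tensors are ours):

```python
def remove_padding(padded, pos_id, hidden):
    """Gather valid rows of padded[batch][max_seqlen][hidden] into a packed buffer."""
    packed = [[0.0] * hidden for _ in range(pos_id[-1])]
    for b in range(len(pos_id) - 1):
        seq_len = pos_id[b + 1] - pos_id[b]
        for s in range(seq_len):
            packed[pos_id[b] + s] = padded[b][s][:]
    return packed
```

Rows beyond each sequence's length are simply never read, which is why the kernel guards the copy with `blockIdx.y < sequence_length`.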
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h → paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cassert>
#include <string>

#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class RemovePaddingPlugin : public DynamicPluginTensorRT {
 public:
  RemovePaddingPlugin() {}
  RemovePaddingPlugin(void const* serial_data, size_t serial_length) {}

  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    RemovePaddingPlugin* ptr = new RemovePaddingPlugin();
    return ptr;
  }

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "remove_padding_plugin";
  }

  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }

  int initialize() TRT_NOEXCEPT { return 0; }
  void terminate() TRT_NOEXCEPT;
  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* outputs,
                       int nbOutputs) TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }
  void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
                       nvinfer1::IGpuAllocator* gpuAllocator)
      TRT_NOEXCEPT override;
  void detachFromContext() TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const
      TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 protected:
  size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; }
  void serialize(void* buffer) const TRT_NOEXCEPT override {}
};

class RemovePaddingPluginCreator : public nvinfer1::IPluginCreator {
 public:
  RemovePaddingPluginCreator() {}

  const char* getPluginName() const TRT_NOEXCEPT override {
    return "remove_padding_plugin";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
    return &field_collection_;
  }

  nvinfer1::IPluginV2* createPlugin(const char* name,
                                    const nvinfer1::PluginFieldCollection*
                                        plugin_field) TRT_NOEXCEPT override {
    return nullptr;
  }

  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, void const* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
    RemovePaddingPlugin* obj =
        new RemovePaddingPlugin(serial_data, serial_length);
    obj->setPluginNamespace(name);
    return obj;
  }

  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
};

REGISTER_TRT_PLUGIN_V2(RemovePaddingPluginCreator);

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu (deleted, 100644 → 0)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <cstring>
#include <vector>

#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() {}

SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(void const* serial_data,
                                                     size_t serial_length) {}

SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() {}

nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const
    TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic();
}

const char* SpecialSlicePluginDynamic::getPluginType() const TRT_NOEXCEPT {
  return "special_slice_plugin";
}

int SpecialSlicePluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }

int SpecialSlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

size_t SpecialSlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  size_t serialize_size = 0;
  return serialize_size;
}

void SpecialSlicePluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {}

nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs output(inputs[0]);
  output.nbDims++;
  for (int i = output.nbDims - 1; i > 1; i--) {
    output.d[i] = inputs[0].d[i - 1];
  }
  auto one = expr_builder.constant(1);
  output.d[1] = one;
  output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
                                       *inputs[1].d[0], *one);
  // remove padding 1
  output.nbDims -= 2;
  return output;
}

void SpecialSlicePluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT {}

size_t SpecialSlicePluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const TRT_NOEXCEPT {
  return 0;
}

void SpecialSlicePluginDynamic::destroy() TRT_NOEXCEPT { delete this; }

void SpecialSlicePluginDynamic::terminate() TRT_NOEXCEPT {}

bool SpecialSlicePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  if (pos == 0)  // slice tensor
    return (desc[pos].type == nvinfer1::DataType::kHALF &&
            desc[pos].format ==
                nvinfer1::TensorFormat::kLINEAR);  // || desc[pos].type ==
                                                   // nvinfer1::DataType::kFLOAT);
  if (pos == 1)  // cu_seqlen
    return (desc[pos].type == nvinfer1::DataType::kINT32 &&
            desc[pos].format == nvinfer1::TensorFormat::kLINEAR);

  return (desc[pos].type == nvinfer1::DataType::kHALF &&
          desc[pos].format ==
              nvinfer1::TensorFormat::kLINEAR);  // || desc[pos].type ==
                                                 // nvinfer1::DataType::kFLOAT);
}

nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The index should be equal to 0"));
  return input_types[0];
}

template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
                                   const int32_t* cu_seqlens, T* output) {
  const int hidden = blockDim.x * gridDim.x;
  const int hidden_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int batch_id = blockIdx.y;

  output[batch_id * hidden + hidden_id] =
      slice_input[cu_seqlens[batch_id] * hidden + hidden_id];
}

int SpecialSlicePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
  auto input_dims = input_desc[0].dims;  // (sum(S), hidden, 1, 1)
  auto out_dims = output_desc[0].dims;   // (batch, hidden, 1, 1)
  PADDLE_ENFORCE_EQ(
      input_desc[0].type, nvinfer1::DataType::kHALF,
      platform::errors::InvalidArgument("Type of input should be half."));

  const int32_t hidden = input_dims.d[1];
  PADDLE_ENFORCE_EQ(hidden % 128, 0,
                    platform::errors::InvalidArgument(
                        "hidden should be multiple of 128."));

  constexpr int num_threads = 128;
  const half* slice_input = static_cast<const half*>(inputs[0]);
  const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
  half* output = static_cast<half*>(outputs[0]);

  const int32_t num_blocks_x = hidden / num_threads;
  const int32_t num_blocks_y = out_dims.d[0];  // batches
  const dim3 num_blocks(num_blocks_x, num_blocks_y);  // blocks

  SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
      slice_input, cu_seqlens, output);
  return cudaGetLastError() != cudaSuccess;
}

SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() {}

const char* SpecialSlicePluginDynamicCreator::getPluginName() const
    TRT_NOEXCEPT {
  return "special_slice_plugin";
}

const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const
    TRT_NOEXCEPT {
  return "1";
}

const nvinfer1::PluginFieldCollection*
SpecialSlicePluginDynamicCreator::getFieldNames() TRT_NOEXCEPT {
  return &field_collection_;
}

nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic();
}

nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  auto plugin = new SpecialSlicePluginDynamic(serial_data, serial_length);
  return plugin;
}

void SpecialSlicePluginDynamicCreator::setPluginNamespace(
    const char* lib_namespace) TRT_NOEXCEPT {
  plugin_namespace_ = lib_namespace;
}

const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const
    TRT_NOEXCEPT {
  return plugin_namespace_.c_str();
}
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
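For context on what this commit deletes: `SpecialSliceKernel` gathered the leading token of every sequence out of the packed buffer, i.e. `output[b] = input[cu_seqlens[b]]` per hidden element (typically used to pull out the first, e.g. CLS, position). A Python sketch of that gather (illustrative only; the function name is ours):

```python
def special_slice(packed, cu_seqlens):
    # Pick row cu_seqlens[b] (the first token of sequence b) for each batch.
    return [packed[cu_seqlens[b]][:] for b in range(len(cu_seqlens) - 1)]
```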
paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu (new file, 0 → 100644)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

__global__ void TransformerInputConvertKernel(const int64_t* input,
                                              int32_t* output0) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  __shared__ int32_t shared_data;
  if (threadIdx.x == static_cast<int>(input[tid])) {
    atomicAdd(&shared_data, 1);
  }
  output0[0] = 0;
  output0[blockIdx.x + 1] = shared_data;
  __syncthreads();
  for (int i = 0; i < blockDim.x; ++i) {
    output0[i + 1] += output0[i];
  }
}

nvinfer1::DataType TransformerInputConvertPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  return nvinfer1::DataType::kINT32;
}

nvinfer1::DimsExprs TransformerInputConvertPlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs output_dims{};
  output_dims.nbDims = 1;
  if (outputIndex == 0) {  // PosId
    const auto* one = exprBuilder.constant(1);
    output_dims.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM,
                                             *inputs[0].d[0], *one);
  } else {  // MaxSeqlen
    output_dims.d[0] = inputs[0].d[1];
  }
  return output_dims;
}

bool TransformerInputConvertPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nbInputs, 1,
                    platform::errors::InvalidArgument("Must have 1 input, "
                                                      "but got %d input(s). ",
                                                      nbInputs));
  PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(),
                    platform::errors::InvalidArgument("Must have 2 outputs, "
                                                      "but got %d output(s). ",
                                                      nbOutputs));
  if (pos == 0) {  // input
    return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  } else {  // output0, output1
    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  }
}

void TransformerInputConvertPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* outputs,
    int nbOutputs) TRT_NOEXCEPT {}

void TransformerInputConvertPlugin::attachToContext(
    cudnnContext* cudnnContext, cublasContext* cublasContext,
    nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}

void TransformerInputConvertPlugin::detachFromContext() TRT_NOEXCEPT {}

void TransformerInputConvertPlugin::terminate() TRT_NOEXCEPT {}

int TransformerInputConvertPlugin::enqueue(
    const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
  const auto input_desc = inputDesc[0];
  const int64_t* input = static_cast<const int64_t*>(inputs[0]);
  int32_t* output0 = static_cast<int32_t*>(outputs[0]);  // PosId
  // int32_t* output1 = static_cast<int32_t*>(outputs[1]);  // MaxSeqlen

  const int32_t num_blocks = input_desc.dims.d[0];   // batches
  const int32_t num_threads = input_desc.dims.d[1];  // max sequence length

  TransformerInputConvertKernel<<<num_blocks, num_threads, 0, stream>>>(
      input, output0);
  return cudaGetLastError() != cudaSuccess;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
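`TransformerInputConvertPlugin` emits two int32 tensors: `PosId`, an exclusive prefix sum of the per-batch valid lengths (hence `batch + 1` entries, matching the `kSUM(d[0], 1)` output shape above), and `MaxSeqlen`. Assuming the per-batch valid lengths have already been counted, the prefix sum that the kernel's final loop builds can be sketched as (Python, illustrative only; the function name is ours):

```python
def build_pos_id(seq_lens):
    # Exclusive prefix sum: pos_id[0] = 0, pos_id[b + 1] = pos_id[b] + seq_lens[b].
    pos_id = [0]
    for n in seq_lens:
        pos_id.append(pos_id[-1] + n)
    return pos_id
```

This is the `pos_id` tensor consumed by the remove/recover padding plugins above.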
paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h (new file, 0 → 100644)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cassert>
#include <string>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class TransformerInputConvertPlugin : public DynamicPluginTensorRT {
 public:
  TransformerInputConvertPlugin() {}

  TransformerInputConvertPlugin(void const* serial_data, size_t serial_length) {}

  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    TransformerInputConvertPlugin* ptr = new TransformerInputConvertPlugin();
    return ptr;
  }

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "transformer_input_convert_plugin";
  }

  int getNbOutputs() const TRT_NOEXCEPT override { return 2; }

  int initialize() TRT_NOEXCEPT { return 0; }

  void terminate() TRT_NOEXCEPT;

  nvinfer1::DimsExprs getOutputDimensions(int outputIndex,
                                          const nvinfer1::DimsExprs* inputs,
                                          int nbInputs,
                                          nvinfer1::IExprBuilder& exprBuilder)
      TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* outputs,
                       int nbOutputs) TRT_NOEXCEPT override;

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  void attachToContext(cudnnContext* cudnnContext,
                       cublasContext* cublasContext,
                       nvinfer1::IGpuAllocator* gpuAllocator)
      TRT_NOEXCEPT override;

  void detachFromContext() TRT_NOEXCEPT override;

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs,
              void* const* outputs,
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const
      TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 protected:
  size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; }

  void serialize(void* buffer) const TRT_NOEXCEPT override {}
};

class TransformerInputConvertPluginCreator : public nvinfer1::IPluginCreator {
 public:
  TransformerInputConvertPluginCreator() {}

  const char* getPluginName() const TRT_NOEXCEPT override {
    return "transformer_input_convert_plugin";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
    return &field_collection_;
  }

  nvinfer1::IPluginV2* createPlugin(
      const char* name,
      const nvinfer1::PluginFieldCollection* plugin_field)
      TRT_NOEXCEPT override {
    return nullptr;
  }

  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         void const* serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    TransformerInputConvertPlugin* obj =
        new TransformerInputConvertPlugin(serial_data, serial_length);
    obj->setPluginNamespace(name);
    return obj;
  }

  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
};

REGISTER_TRT_PLUGIN_V2(TransformerInputConvertPluginCreator);

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
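The plugin above declares two outputs (`getNbOutputs()` returns 2). In variable-length ("varseqlen") transformer pipelines, an input-convert step of this kind typically derives a prefix-sum offset tensor and the maximum sequence length from the padding mask, so downstream kernels can address packed (padding-free) token data. A minimal plain-Python sketch of that computation — illustrative only, not Paddle or TensorRT code, and `convert_transformer_input` is a hypothetical name:

```python
def convert_transformer_input(mask_rows):
    """Given 0/1 padding-mask rows, return (cumulative offsets, max length).

    Hypothetical illustration of a varseqlen input-convert step: the offsets
    tell each packed sequence where it starts in the concatenated token
    buffer; the max length bounds per-sequence work on the device.
    """
    seq_lens = [sum(row) for row in mask_rows]  # valid tokens per sequence
    offsets = [0]
    for n in seq_lens:
        offsets.append(offsets[-1] + n)  # prefix sum over sequence lengths
    return offsets, max(seq_lens)


offsets, max_seqlen = convert_transformer_input(
    [[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 1, 1]])
print(offsets, max_seqlen)  # [0, 3, 5, 9] 4
```

With these offsets, sequence `i` occupies positions `offsets[i]:offsets[i+1]` of the packed buffer, which is why a remove-padding/recover-padding pass can round-trip the data losslessly.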
paddle/fluid/inference/tests/api/analyzer_capi_exp_gpu_tester.cc
@@ -65,7 +65,7 @@ TEST(PD_Config, gpu_interface) {
                                     &min_shape_ptr, &max_shape_ptr,
                                     &opt_shape_ptr, FALSE);
   PD_ConfigDisableTensorRtOPs(config, 1, &ops_name);
-  PD_ConfigEnableTensorRtOSS(config);
+  PD_ConfigEnableVarseqlen(config);
   bool oss_enabled = PD_ConfigTensorRtOssEnabled(config);
   EXPECT_TRUE(oss_enabled);
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
@@ -210,7 +210,11 @@ std::shared_ptr<paddle_infer::Predictor> InitPredictor() {
   config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
                                 opt_input_shape);
   // erinie varlen must be used with oss
-  config.EnableTensorRtOSS();
+  config.EnableVarseqlen();
+  paddle_infer::experimental::InternalUtils::SetTransformerPosid(&config,
+                                                                 input_name2);
+  paddle_infer::experimental::InternalUtils::SetTransformerMaskid(&config,
+                                                                  input_name3);
   return paddle_infer::CreatePredictor(config);
 }
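The varseqlen path enabled in the diff above removes padding so the transformer processes only real tokens. A back-of-the-envelope illustration of the saving, with made-up sequence lengths (plain Python, not Paddle code):

```python
# Illustrative-only numbers: tokens processed by a padded batch versus a
# packed (varseqlen) batch for the same three sequences.
seq_lens = [3, 2, 4]                       # real lengths of three sequences
padded = len(seq_lens) * max(seq_lens)     # padded batch: batch * max_len tokens
packed = sum(seq_lens)                     # varseqlen: only the real tokens
print(padded, packed)                      # 12 vs 9 tokens
```

The gap widens with skewed length distributions, which is why ERNIE/BERT-style serving commonly pairs dynamic shape with variable-length execution.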
paddle/fluid/inference/tests/infer_ut/test_ernie_xnli_int8.cc
@@ -68,7 +68,7 @@ std::shared_ptr<Predictor> InitPredictor() {
   config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
                                 opt_input_shape);
   // erinie varlen must be used with oss
-  config.EnableTensorRtOSS();
+  config.EnableVarseqlen();
   return CreatePredictor(config);
 }
paddle/fluid/inference/utils/table_printer_tester.cc
@@ -43,7 +43,7 @@ TEST(table_printer, output) {
   table.InsertRow({"trt_precision", "fp32"});
   table.InsertRow({"enable_dynamic_shape", "true"});
   table.InsertRow({"DisableTensorRtOPs", "{}"});
-  table.InsertRow({"EnableTensorRtOSS", "ON"});
+  table.InsertRow({"EnableVarseqlen", "ON"});
   table.InsertRow({"tensorrt_dla_enabled", "ON"});
   table.InsetDivider();
paddle/fluid/pybind/inference_api.cc
@@ -657,8 +657,9 @@ void BindAnalysisConfig(py::module *m) {
              py::arg("disable_trt_plugin_fp16") = false)
         .def("tensorrt_dynamic_shape_enabled",
              &AnalysisConfig::tensorrt_dynamic_shape_enabled)
-        .def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS)
-        .def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled)
+        .def("enable_tensorrt_varseqlen", &AnalysisConfig::EnableVarseqlen)
+        .def("tensorrt_varseqlen_enabled",
+             &AnalysisConfig::tensorrt_varseqlen_enabled)
         .def("collect_shape_range_info", &AnalysisConfig::CollectShapeRangeInfo)
         .def("shape_range_info_path", &AnalysisConfig::shape_range_info_path)
         .def("shape_range_info_collected",
python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
@@ -42,7 +42,7 @@ class InferencePassTest(unittest.TestCase):
         self.enable_mkldnn = False
         self.enable_mkldnn_bfloat16 = False
         self.enable_trt = False
-        self.enable_tensorrt_oss = True
+        self.enable_tensorrt_varseqlen = True
         self.trt_parameters = None
         self.dynamic_shape_params = None
         self.enable_lite = False
@@ -134,8 +134,8 @@ class InferencePassTest(unittest.TestCase):
                     self.dynamic_shape_params.max_input_shape,
                     self.dynamic_shape_params.optim_input_shape,
                     self.dynamic_shape_params.disable_trt_plugin_fp16)
-            if self.enable_tensorrt_oss:
-                config.enable_tensorrt_oss()
+            if self.enable_tensorrt_varseqlen:
+                config.enable_tensorrt_varseqlen()
         elif use_mkldnn:
             config.enable_mkldnn()
python/paddle/fluid/tests/unittests/ir/inference/quant_dequant_test.py
@@ -46,7 +46,7 @@ class QuantDequantTest(unittest.TestCase):
         self.enable_mkldnn = False
         self.enable_mkldnn_bfloat16 = False
         self.enable_trt = False
-        self.enable_tensorrt_oss = True
+        self.enable_tensorrt_varseqlen = True
        self.trt_parameters = None
         self.dynamic_shape_params = None
         self.enable_lite = False
@@ -184,8 +184,8 @@ class QuantDequantTest(unittest.TestCase):
                     self.dynamic_shape_params.max_input_shape,
                     self.dynamic_shape_params.optim_input_shape,
                     self.dynamic_shape_params.disable_trt_plugin_fp16)
-            if self.enable_tensorrt_oss:
-                config.enable_tensorrt_oss()
+            if self.enable_tensorrt_varseqlen:
+                config.enable_tensorrt_varseqlen()
         elif use_mkldnn:
             config.enable_mkldnn()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms3_op.py
@@ -179,7 +179,7 @@ def multiclass_nms(bboxes,
 class TensorRTMultiClassNMS3Test(InferencePassTest):
     def setUp(self):
         self.enable_trt = True
-        self.enable_tensorrt_oss = True
+        self.enable_tensorrt_varseqlen = True
         self.precision = AnalysisConfig.Precision.Float32
         self.serialize = False
         self.bs = 1
@@ -291,8 +291,8 @@ class TensorRTMultiClassNMS3Test(InferencePassTest):
         self.background = 7
         self.run_test()

-    def test_disable_oss(self):
-        self.diable_tensorrt_oss = False
+    def test_disable_varseqlen(self):
+        self.diable_tensorrt_varseqlen = False
         self.run_test()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py
@@ -25,7 +25,7 @@ from paddle.fluid.core import AnalysisConfig
 class TensorRTMultiClassNMSTest(InferencePassTest):
     def setUp(self):
         self.enable_trt = True
-        self.enable_tensorrt_oss = True
+        self.enable_tensorrt_varseqlen = True
         self.precision = AnalysisConfig.Precision.Float32
         self.serialize = False
         self.bs = 1
@@ -135,8 +135,8 @@ class TensorRTMultiClassNMSTest(InferencePassTest):
         self.background = 7
         self.run_test()

-    def test_disable_oss(self):
-        self.diable_tensorrt_oss = False
+    def test_disable_varseqlen(self):
+        self.diable_tensorrt_varseqlen = False
         self.run_test()