BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit b2bb7ec9
Authored Feb 09, 2023 by Wang Bojun; committed via GitHub on Feb 09, 2023
[TRT] Transpose layernorm fusion with different input format (#50082)
* trans_layernorm
Parent: b3f60f39
Showing 14 changed files with 1509 additions and 68 deletions (+1509, -68)
Changed files:
  paddle/fluid/framework/ir/CMakeLists.txt                                              +1    -0
  paddle/fluid/framework/ir/trans_layernorm_fuse_pass.cc                                +207  -0
  paddle/fluid/framework/ir/trans_layernorm_fuse_pass.h                                 +132  -0
  paddle/fluid/inference/api/analysis_predictor.cc                                      +1    -0
  paddle/fluid/inference/api/paddle_pass_builder.cc                                     +18   -14
  paddle/fluid/inference/tensorrt/convert/CMakeLists.txt                                +1    -0
  paddle/fluid/inference/tensorrt/convert/trans_layernorm_op.cc                         +90   -0
  paddle/fluid/inference/tensorrt/op_teller.cc                                          +9    -1
  paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt                                 +1    -0
  paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu                  +66   -53
  paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.cu                   +555  -0
  paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.h                    +176  -0
  python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt                       +3    -0
  python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_trans_layernorm.py +249  -0
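For reference, the subgraph this commit fuses can be written with the public paddle Python API as the minimal sketch below (not part of the commit; shapes and names are illustrative):

import paddle
import paddle.nn.functional as F

# An NCHW activation, e.g. the output of a conv2d.
n, c, h, w = 2, 128, 32, 32
x = paddle.randn([n, c, h, w])
weight = paddle.ones([c])   # layer_norm Scale
bias = paddle.zeros([c])    # layer_norm Bias

# The pattern matched by trans_layernorm_fuse_pass:
t = paddle.transpose(x, perm=[0, 2, 3, 1])  # (n, h, w, c)
r = paddle.reshape(t, shape=[n, h * w, c])  # (n, h*w, c), kept as an output
y = F.layer_norm(r, normalized_shape=[c], weight=weight, bias=bias,
                 epsilon=1e-5)              # begin_norm_axis = 2 (or -1)

After fusion, a single trans_layernorm op produces both r (Out_reshape) and y (Out_layernorm) directly from x.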
paddle/fluid/framework/ir/CMakeLists.txt
@@ -145,6 +145,7 @@ if(WITH_TENSORRT)
   pass_library(elementwise_groupnorm_act_pass inference)
   pass_library(preln_elementwise_groupnorm_act_pass inference)
   pass_library(groupnorm_act_pass inference)
+  pass_library(trans_layernorm_fuse_pass inference)
   pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
 endif()
paddle/fluid/framework/ir/trans_layernorm_fuse_pass.cc (new file, mode 100644)
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/trans_layernorm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
}  // namespace ir
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct TransLayernormPattern : public PatternBase {
  TransLayernormPattern(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "trans_layernorm") {}
  void operator()(PDNode *x);
  PATTERN_DECL_NODE(transpose);
  PATTERN_DECL_NODE(transpose_output);
  PATTERN_DECL_NODE(reshape);
  PATTERN_DECL_NODE(reshape_output);
  PATTERN_DECL_NODE(layernorm);
  PATTERN_DECL_NODE(layernorm_scale);
  PATTERN_DECL_NODE(layernorm_bias);
  PATTERN_DECL_NODE(layernorm_output);
};

void TransLayernormPattern::operator()(PDNode *x) {
  std::unordered_set<std::string> reshape_ops{"reshape2",
                                              "flatten_contiguous_range"};
  auto *transpose =
      pattern->NewNode(transpose_repr())->assert_is_op("transpose2");
  auto *transpose_output = pattern->NewNode(transpose_output_repr())
                               ->assert_is_op_output("transpose2")
                               ->assert_is_ops_input(reshape_ops, "X");
  transpose->LinksFrom({x}).LinksTo({transpose_output});
  auto *reshape =
      pattern->NewNode(reshape_repr())->assert_is_ops(reshape_ops);
  auto *reshape_output = pattern->NewNode(reshape_output_repr())
                             ->assert_is_ops_output(reshape_ops, "Out")
                             ->assert_is_op_input("layer_norm", "X")
                             ->AsOutput();
  reshape->LinksFrom({transpose_output}).LinksTo({reshape_output});
  auto *layernorm =
      pattern->NewNode(layernorm_repr())->assert_is_op("layer_norm");
  auto *layernorm_scale = pattern->NewNode(layernorm_scale_repr())
                              ->assert_is_op_input("layer_norm", "Scale")
                              ->AsInput();
  auto *layernorm_bias = pattern->NewNode(layernorm_bias_repr())
                             ->assert_is_op_input("layer_norm", "Bias")
                             ->AsInput();
  auto *layernorm_output = pattern->NewNode(layernorm_output_repr())
                               ->assert_is_op_output("layer_norm", "Y")
                               ->AsOutput();
  layernorm->LinksFrom({reshape_output, layernorm_scale, layernorm_bias})
      .LinksTo({layernorm_output});
}
}  // namespace patterns

// This pass makes a fusion as below:
//
//       |
//   transpose(axis=[0,2,3,1])
//       |
//   reshape(n,h*w,c)
//     |      |
//    out   layernorm(begin_norm_axis=2 or -1)
//            |
//      layernorm_out
//
// -> fuse to
//
//       |
//   trans_layernorm
//     |      |
//    out   layernorm_out
//
int TransLayernormFusePass::ApplyConvTransLayernormPattern(
    ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph,
      platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("trans_layernorm_fuse", graph);
  int found_subgraph_count = 0;
  GraphPatternDetector gpd;

  PDNode *x = nullptr;
  x = gpd.mutable_pattern()
          ->NewNode("trans_layernorm_fuse/x")
          ->AsInput()
          ->assert_var_not_persistable()
          ->assert_is_op_input("transpose2", "X");
  patterns::TransLayernormPattern fused_pattern(gpd.mutable_pattern(),
                                                "trans_layernorm_fuse");
  fused_pattern(x);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    if (subgraph.count(x) <= 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }
    VLOG(4) << "handle transpose layernorm fuse";
    GET_IR_NODE_FROM_SUBGRAPH(transpose, transpose, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose_output, transpose_output, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape, reshape, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape_output, reshape_output, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layernorm, layernorm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layernorm_scale, layernorm_scale, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layernorm_bias, layernorm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layernorm_output, layernorm_output, fused_pattern);

    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "transpose layernorm pass in op compat failed.";
      return;
    }
    // trans_layernorm is suited to an nchw-to-nhwc transpose before the
    // layernorm; check for it.
    std::vector<int> trans_axis =
        PADDLE_GET_CONST(std::vector<int>, transpose->Op()->GetAttr("axis"));
    if (trans_axis != std::vector<int>{0, 2, 3, 1}) {
      VLOG(1) << "transpose layernorm fuse pass, transpose axis check fail, "
                 "stop fusion";
      return;
    }
    if (reshape->Op()->Type() == "flatten_contiguous_range") {
      int start_axis =
          PADDLE_GET_CONST(int, reshape->Op()->GetAttr("start_axis"));
      int stop_axis =
          PADDLE_GET_CONST(int, reshape->Op()->GetAttr("stop_axis"));
      if (!(start_axis == 1 && stop_axis == 2)) {
        VLOG(1) << "transpose layernorm fuse pass, flatten axis check fail, "
                   "stop fusion";
        return;
      }
    } else if (reshape->Op()->Type() == "reshape2") {
      std::vector<int> reshape_shape =
          PADDLE_GET_CONST(std::vector<int>, reshape->Op()->GetAttr("shape"));
      if (reshape_shape.size() != 3) {
        VLOG(1)
            << "transpose layernorm fuse pass, reshape check fail, stop fusion";
        return;
      }
    }
    auto layernorm_begin_norm_axis =
        PADDLE_GET_CONST(int, layernorm->Op()->GetAttr("begin_norm_axis"));
    if (layernorm_begin_norm_axis != 2 && layernorm_begin_norm_axis != -1) {
      VLOG(1) << "transpose layernorm fuse pass, layernorm begin norm axis "
                 "check fail, stop fusion";
      return;
    }
    std::unordered_set<const Node *> del_node_set;
    // Create a trans_layernorm op node.
    OpDesc new_desc(*layernorm->Op());
    new_desc.SetType("trans_layernorm");
    new_desc.SetInput("X", {subgraph.at(x)->Name()});
    new_desc.SetOutput("Out_reshape", {reshape_output->Name()});
    new_desc.SetOutput("Out_layernorm", {layernorm_output->Name()});
    new_desc.RemoveOutput("Y");
    new_desc.Flush();
    auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.
    del_node_set.insert(transpose);
    del_node_set.insert(transpose_output);
    del_node_set.insert(reshape);
    del_node_set.insert(layernorm);
    GraphSafeRemoveNodes(graph, del_node_set);

    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
    IR_NODE_LINK_TO(layernorm_scale, fused_node);
    IR_NODE_LINK_TO(layernorm_bias, fused_node);
    IR_NODE_LINK_TO(fused_node, reshape_output);
    IR_NODE_LINK_TO(fused_node, layernorm_output);
    found_subgraph_count++;
  };
  gpd(graph, handler);
  return found_subgraph_count;
}

void TransLayernormFusePass::ApplyImpl(ir::Graph *graph) const {
  FusePassBase::Init("trans_layernorm_fuse_pass", graph);
  int found_subgraph_count = ApplyConvTransLayernormPattern(graph);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(trans_layernorm_fuse_pass,
              paddle::framework::ir::TransLayernormFusePass);
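The axis checks above pin down the semantics the fused kernel must reproduce: with transpose axis [0,2,3,1], a 3-D reshape, and begin_norm_axis 2 or -1, the layernorm normalizes each (batch, pixel) vector of c channels. A quick numpy sanity check (not from the commit):

import numpy as np

n, c, h, w = 2, 8, 4, 4
eps = 1e-5
x = np.random.randn(n, c, h, w).astype(np.float32)

# The matched subgraph: transpose + reshape + layernorm over the last axis.
r = x.transpose(0, 2, 3, 1).reshape(n, h * w, c)
y = (r - r.mean(-1, keepdims=True)) / np.sqrt(r.var(-1, keepdims=True) + eps)

# Reference: normalize each (batch, pixel) channel vector directly.
ref = np.empty_like(y)
for b in range(n):
    for p in range(h * w):
        v = x[b, :, p // w, p % w]
        ref[b, p] = (v - v.mean()) / np.sqrt(v.var() + eps)
assert np.allclose(y, ref, atol=1e-5)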
paddle/fluid/framework/ir/trans_layernorm_fuse_pass.h (new file, mode 100644)
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {

// This pass aims to fuse the structure below:
//
//       |
//   transpose(axis=[0,2,3,1])
//       |
//   reshape(n,h*w,c)
//     |      |
//    out   layernorm(begin_norm_axis=2 or -1)
//            |
//      layernorm_out
//
// -> fuse to
//
//       |
//   trans_layernorm
//     |      |
//    out   layernorm_out
//
class Graph;

class TransLayernormFusePass : public FusePassBase {
 public:
  TransLayernormFusePass() {
    AddOpCompat(OpCompat("reshape2"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Shape")
        .IsTensor()
        .IsOptional()
        .End()
        .AddInput("ShapeTensor")
        .IsOptional()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddOutput("XShape")
        .IsOptional()
        .IsTensor()
        .End()
        .AddAttr("shape")
        .IsType<std::vector<int>>()
        .End();
    AddOpCompat(OpCompat("flatten_contiguous_range"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddOutput("XShape")
        .IsOptional()
        .IsTensor()
        .End()
        .AddAttr("start_axis")
        .IsNumEQ(1)
        .End()
        .AddAttr("stop_axis")
        .IsNumEQ(2)
        .End();
    AddOpCompat(OpCompat("transpose2"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddOutput("XShape")
        .IsOptional()
        .IsTensor()
        .End()
        .AddAttr("axis")  // {0, 2, 3, 1}
        .IsType<std::vector<int>>()
        .End();
    AddOpCompat(OpCompat("layer_norm"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Scale")
        .IsTensor()
        .End()
        .AddInput("Bias")
        .IsTensor()
        .End()
        .AddOutput("Y")
        .IsTensor()
        .End()
        .AddOutput("Mean")
        .IsTensor()
        .End()
        .AddOutput("Variance")
        .IsTensor()
        .End()
        .AddAttr("begin_norm_axis")
        .IsNumGT(0)
        .End()
        .AddAttr("epsilon")
        .IsNumGE(0.0f)
        .IsNumLE(1.0f)
        .End();
  }
  virtual ~TransLayernormFusePass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;
  int ApplyConvTransLayernormPattern(ir::Graph *graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
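Putting the pass and header together, the fused op's interface looks like the illustrative op-config dict below, written in the style used by the unit test at the end of this commit (tensor names are hypothetical; the input/output keys come from the pass source):

trans_layernorm_op = {
    "op_type": "trans_layernorm",
    "op_inputs": {
        "X": ["transpose2_input"],      # the tensor that fed transpose2
        "Scale": ["layernorm_scale"],
        "Bias": ["layernorm_bias"],
    },
    "op_outputs": {
        "Out_reshape": ["reshape_out"],       # the (n, h*w, c) tensor
        "Out_layernorm": ["layernorm_out"],   # the normalized tensor
    },
}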
paddle/fluid/inference/api/analysis_predictor.cc
@@ -2490,6 +2490,7 @@ USE_TRT_CONVERTER(layernorm_shift_partition)
 USE_TRT_CONVERTER(reverse_roll)
 USE_TRT_CONVERTER(preln_layernorm_shift_partition)
 USE_TRT_CONVERTER(merge_layernorm)
+USE_TRT_CONVERTER(trans_layernorm)
 USE_TRT_CONVERTER(skip_merge_layernorm)
 USE_TRT_CONVERTER(generic_plugin_creater)
 USE_TRT_CONVERTER(custom_plugin_creater)
paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -127,6 +127,10 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "trt_map_matmul_to_mul_pass",                  //
       "fc_fuse_pass",                                //
       "conv_elementwise_add_fuse_pass",              //
+#if defined _WIN32  // Windows CI is TensorRT7.0. Remove this after upgrading.
+#else
+      "trans_layernorm_fuse_pass",                   //
+#endif
       "remove_padding_recover_padding_pass",         //
       "delete_remove_padding_recover_padding_pass",  //
       // "yolo_box_fuse_pass",      //
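Since the pass is registered in kTRTSubgraphPasses and the plugin below only supports dynamic shape, it takes effect when TensorRT is enabled on a predictor config with dynamic-shape info, roughly as in this sketch (model paths and shapes are placeholders, not from the commit):

import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(1000, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=4,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Half,
    use_static=False,
    use_calib_mode=False,
)
# Dynamic shape is required for trans_layernorm to be offloaded to TRT.
config.set_trt_dynamic_shape_info(
    {"conv2d_input": [1, 128, 32, 32]},   # min
    {"conv2d_input": [4, 128, 64, 64]},   # max
    {"conv2d_input": [4, 128, 64, 64]},   # opt
)
predictor = paddle_infer.create_predictor(config)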
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -92,6 +92,7 @@ list(
   take_along_axis_op.cc
   logsigmoid_op.cc
   preln_layernorm_shift_partition_op.cc
+  trans_layernorm_op.cc
   merge_layernorm_op.cc
   skip_merge_layernorm_op.cc
   generic_and_custom_plugin_creater.cc
paddle/fluid/inference/tensorrt/convert/trans_layernorm_op.cc (new file, mode 100644)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {

class TransLayerNormOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc &op,
                  const framework::Scope &scope,
                  bool test_mode) override {
    VLOG(4) << "convert a trans_layer_norm fused op to tensorrt "
               "trans_layernorm plugin";
    framework::OpDesc op_desc(op, nullptr);

    auto *X = engine_->GetITensor(op_desc.Input("X").front());
    auto *Bias_v = scope.FindVar(op_desc.Input("Bias").front());
    auto *Scale_v = scope.FindVar(op_desc.Input("Scale").front());
    // We already checked begin_norm_axis in the pass. Here we set
    // begin_norm_axis to 3 to fit the calculation in the TRT plugin.
    const int begin_norm_axis = 3;
    const float eps = op_desc.HasAttr("epsilon")
                          ? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"))
                          : 1e-5f;
    PADDLE_ENFORCE_NOT_NULL(
        Bias_v,
        platform::errors::InvalidArgument(
            "Input(Bias) of layer_norm should not be null."));
    PADDLE_ENFORCE_NOT_NULL(
        Scale_v,
        platform::errors::InvalidArgument(
            "Input(Scale) of layer_norm should not be null."));
    auto *Bias_t = Bias_v->GetMutable<phi::DenseTensor>();
    auto *Scale_t = Scale_v->GetMutable<phi::DenseTensor>();

    auto bias_weight =
        engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t);
    auto scale_weight =
        engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t);

    nvinfer1::ILayer *layernorm_layer = nullptr;
    if (engine_->with_dynamic_shape()) {
      // For dynamic shape, the shapes of mean and variance will be
      // determined in configurePlugin.
      std::vector<int64_t> mean_shape{1};
      std::vector<int64_t> variance_shape{1};
      bool with_fp16 =
          engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
      plugin::TransLayerNormPluginDynamic *plugin =
          new plugin::TransLayerNormPluginDynamic(
              static_cast<const float *>(bias_weight.get().values),
              bias_weight.get().count,
              static_cast<const float *>(scale_weight.get().values),
              scale_weight.get().count,
              begin_norm_axis,
              eps,
              mean_shape,
              variance_shape,
              with_fp16);
      layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "trans_layernorm do not support static shape mode yet"));
    }

    auto output_layernorm_name = op_desc.Output("Out_layernorm").front();
    auto output_reshape_name = op_desc.Output("Out_reshape").front();
    RreplenishLayerAndOutput(layernorm_layer,
                             "trans_layernorm",
                             {output_layernorm_name, output_reshape_name},
                             test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(trans_layernorm, TransLayerNormOpConverter);
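The hard-coded begin_norm_axis = 3 works because inside the plugin the layernorm runs on the transposed 4-D (n, h, w, c) tensor, where axis 3 is the channel axis; normalizing from axis 3 there computes exactly what begin_norm_axis = 2 computed on the (n, h*w, c) view. A numpy check of that equivalence (not from the commit):

import numpy as np

n, c, h, w = 2, 8, 4, 4

def norm_from_last_axis(a, eps=1e-5):
    m = a.mean(-1, keepdims=True)
    v = a.var(-1, keepdims=True)
    return (a - m) / np.sqrt(v + eps)

x = np.random.randn(n, h, w, c).astype(np.float32)   # transposed input
y4 = norm_from_last_axis(x)                          # begin_norm_axis = 3 on 4-D
y3 = norm_from_last_axis(x.reshape(n, h * w, c))     # begin_norm_axis = 2 on 3-D
assert np.allclose(y4.reshape(n, h * w, c), y3)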
paddle/fluid/inference/tensorrt/op_teller.cc
@@ -2489,7 +2489,13 @@ struct SimpleOpTypeSetTeller : public Teller {
         return false;
       }
     }
+    if (op_type == "trans_layernorm") {
+      if (!with_dynamic_shape) {
+        VLOG(3) << "The trans_layernorm op does not support "
+                   "static shape yet";
+        return false;
+      }
+    }
     if (op_type == "lookup_table") {
       if (!with_dynamic_shape) {
         VLOG(3) << "the lookup_table does not support "
@@ -2659,6 +2665,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "logsigmoid",
       "preln_layernorm_shift_partition",
       "lookup_table",
+      "trans_layernorm",
       "merge_layernorm",
       "skip_merge_layernorm",
       "lookup_table_v2",
@@ -2808,6 +2815,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "take_along_axis",
       "logsigmoid",
       "preln_layernorm_shift_partition",
+      "trans_layernorm",
       "merge_layernorm",
       "skip_merge_layernorm",
       "lookup_table",
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -33,6 +33,7 @@ list(
   layernorm_shift_partition_op.cu
   reverse_roll_op_plugin.cu
   prelnlayernorm_shift_partition_op.cu
+  trans_layernorm_op_plugin.cu
   merge_layernorm_op_plugin.cu
   skip_merge_layernorm_op_plugin.cu
   skip_groupnorm_act_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
@@ -26,12 +26,21 @@
 #include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
 #include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/operators/math/bert_encoder_functor.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/kernels/funcs/math_cuda_utils.h"

 namespace paddle {
 namespace inference {
 namespace tensorrt {
 namespace plugin {
+
+inline int getSMVersion() {
+  const int device = phi::backends::gpu::GetCurrentDeviceId();
+  const phi::gpuDeviceProp prop =
+      phi::backends::gpu::GetDeviceProperties(device);
+  return prop.major * 10 + prop.minor;
+}
+
 #ifdef TRT_PLUGIN_FP16_AVALIABLE
 #define FINAL_MASK 0xffffffff
@@ -425,7 +434,10 @@ int PrelnResidualBiasPluginDynamic::enqueue(
   float *mean = nullptr;
   float *var = nullptr;
   const int VecSize = 8;
-#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+  int sm = getSMVersion();
+  // sm >= 60 to support __ldg
+  if (sm >= 60) {
     // if hidden is even, use half2 kernel generalAddBiasResidualLayerNormOpt2
     if (hidden % 2 == 0) {
       int half_n = hidden / 2;
@@ -479,7 +491,8 @@ int PrelnResidualBiasPluginDynamic::enqueue(
           var,
           stream);
     }
-#else
+    } else {
+      // if sm < 60, use FusedLayernormResidualDropoutBiasFunctor only
       paddle::operators::FusedLayernormResidualDropoutBiasFunctor<half,
                                                                   uint8_t,
                                                                   VecSize,
@@ -504,7 +517,7 @@ int PrelnResidualBiasPluginDynamic::enqueue(
         mean,
         var,
         stream);
-#endif
+    }
 #else
   PADDLE_THROW(platform::errors::Fatal(
       "The Ernie(Bert) tensorRT plugin should be "
paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.cu (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

inline int getSMVersion() {
  const int device = phi::backends::gpu::GetCurrentDeviceId();
  const phi::gpuDeviceProp prop =
      phi::backends::gpu::GetDeviceProperties(device);
  return prop.major * 10 + prop.minor;
}

#ifdef TRT_PLUGIN_FP16_AVALIABLE
#define FINAL_MASK 0xffffffff

template <int UNROLL_FACTOR>
__global__ void GeneralResidualLayerNormOpt2(half2 *normed_output,
                                             half2 *output,
                                             const half2 *__restrict src,
                                             const half2 *__restrict gamma,
                                             const half2 *__restrict beta,
                                             int m,
                                             int n,
                                             float epsilon) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  __shared__ float s_mean;
  __shared__ float s_variance;
  float x_sum = 0.0f;
  float x2_sum = 0.0f;
  const int b_offset = blockIdx.x * n;

#pragma unroll UNROLL_FACTOR
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int index = b_offset + i;
    float val_1 = 0.0f;
    float val_2 = 0.0f;
    half2 tmp;
    tmp = __ldg(&src[index]);
    val_1 += static_cast<float>(tmp.x);
    val_2 += static_cast<float>(tmp.y);
    output[index] = tmp;
    x_sum += val_1 + val_2;
    x2_sum += val_1 * val_1 + val_2 * val_2;
  }
  float sums[2];
  sums[0] = x_sum;
  sums[1] = x2_sum;
  phi::funcs::BlockReduceSumV2<float, 2>(sums);

  constexpr int Half2VecSize = 2;
  if (threadIdx.x == 0) {
    s_mean = sums[0] / n / Half2VecSize;
    s_variance =
        rsqrtf(sums[1] / n / Half2VecSize - s_mean * s_mean + epsilon);
  }
  __syncthreads();

  half2 mean_2 = __float2half2_rn(s_mean);
  half2 var_2 = __float2half2_rn(s_variance);

#pragma unroll UNROLL_FACTOR
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int index = b_offset + i;
    half2 val = __hmul2(__hmul2(__hsub2(output[index], mean_2), var_2),
                        __ldg(&gamma[i]));
    if (beta) {
      val = __hadd2(val, __ldg(&beta[i]));
    }
    normed_output[index] = val;
  }
#endif
}

#define HALF2_RESIDUAL_LAYERNORM_OPT2(UNROLL_FACTOR)                         \
  GeneralResidualLayerNormOpt2<UNROLL_FACTOR>                                \
      <<<rows, block, 0, stream>>>(reinterpret_cast<half2 *>(layernorm_dst), \
                                   reinterpret_cast<half2 *>(dst),           \
                                   (const half2 *)input,                     \
                                   (const half2 *)fp16_scale_gpu_,           \
                                   (const half2 *)fp16_bias_gpu_,            \
                                   rows,                                     \
                                   half_n,                                   \
                                   eps);
#endif

int TransLayerNormPluginDynamic::initialize() TRT_NOEXCEPT {
  if (!with_fp16_) {
    cudaMalloc(&bias_gpu_, sizeof(float) * bias_.size());
    cudaMemcpy(bias_gpu_,
               bias_.data(),
               bias_.size() * sizeof(float),
               cudaMemcpyHostToDevice);
    cudaMalloc(&scale_gpu_, sizeof(float) * scale_.size());
    cudaMemcpy(scale_gpu_,
               scale_.data(),
               scale_.size() * sizeof(float),
               cudaMemcpyHostToDevice);
  } else {
    cudaMalloc(&bias_gpu_, sizeof(float) * bias_.size());
    cudaMemcpy(bias_gpu_,
               bias_.data(),
               bias_.size() * sizeof(float),
               cudaMemcpyHostToDevice);
    cudaMalloc(&scale_gpu_, sizeof(float) * scale_.size());
    cudaMemcpy(scale_gpu_,
               scale_.data(),
               scale_.size() * sizeof(float),
               cudaMemcpyHostToDevice);
    std::vector<half> fp16_bias_;
    std::vector<half> fp16_scale_;
    fp16_bias_.resize(bias_.size());
    fp16_scale_.resize(scale_.size());
    for (int i = 0; i < bias_.size(); i++) {
      fp16_bias_[i] = static_cast<half>(bias_[i]);
    }
    for (int i = 0; i < scale_.size(); i++) {
      fp16_scale_[i] = static_cast<half>(scale_[i]);
    }
    cudaMalloc(&fp16_bias_gpu_, sizeof(half) * fp16_bias_.size());
    cudaMemcpy(fp16_bias_gpu_,
               fp16_bias_.data(),
               fp16_bias_.size() * sizeof(half),
               cudaMemcpyHostToDevice);
    cudaMalloc(&fp16_scale_gpu_, sizeof(half) * fp16_scale_.size());
    cudaMemcpy(fp16_scale_gpu_,
               fp16_scale_.data(),
               fp16_scale_.size() * sizeof(half),
               cudaMemcpyHostToDevice);
  }
  return 0;
}

void TransLayerNormPluginDynamic::terminate() TRT_NOEXCEPT {
  if (bias_gpu_) {
    cudaFree(bias_gpu_);
    bias_gpu_ = nullptr;
  }
  if (scale_gpu_) {
    cudaFree(scale_gpu_);
    scale_gpu_ = nullptr;
  }
  if (fp16_bias_gpu_) {
    cudaFree(fp16_bias_gpu_);
    fp16_bias_gpu_ = nullptr;
  }
  if (fp16_scale_gpu_) {
    cudaFree(fp16_scale_gpu_);
    fp16_scale_gpu_ = nullptr;
  }
}

nvinfer1::DimsExprs TransLayerNormPluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs *inputDims,
    int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs ret;
  ret.nbDims = 3;
  ret.d[0] = inputDims[0].d[0];
  ret.d[1] = expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
                                    *inputDims[0].d[2],
                                    *inputDims[0].d[3]);
  ret.d[2] = inputDims[0].d[1];
  return ret;
}

bool TransLayerNormPluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc *in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  int feature_size = bias_.size();
  PADDLE_ENFORCE_GE(feature_size,
                    0,
                    platform::errors::InvalidArgument(
                        "The feature size of layernorm feature_size must be "
                        "positive, but got:%d",
                        feature_size));
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of layernorm plugin should not be nullptr."));
  PADDLE_ENFORCE_LT(
      pos,
      nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos,
                                        nb_inputs + nb_outputs));
  const nvinfer1::PluginTensorDesc &in = in_out[pos];
  if (pos == 0) {
    if (with_fp16_) {
      if (feature_size % 8 == 0) {
        // For now, we only support kHWC8 when feature_size % 8 == 0.
        return ((in.type == nvinfer1::DataType::kHALF) &&
                (in.format == nvinfer1::PluginFormat::kLINEAR ||
                 in.format == nvinfer1::PluginFormat::kHWC8));
      } else {
        return ((in.type == nvinfer1::DataType::kHALF) &&
                (in.format == nvinfer1::PluginFormat::kLINEAR));
      }
    } else {
      return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
  }
  if (pos == 1) {
    if (with_fp16_) {
      return (in.type == in_out[0].type &&
              (in.format == nvinfer1::PluginFormat::kLINEAR));
    } else {
      return (in.type == in_out[0].type) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
  }
  if (pos == 2) {
    if (with_fp16_) {
      return (in.type == in_out[0].type &&
              (in.format == nvinfer1::PluginFormat::kLINEAR));
    } else {
      return (in.type == in_out[0].type) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
  }
}

void TransLayerNormPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *in,
    int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc *out,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(
      begin_norm_axis_,
      3,
      platform::errors::InvalidArgument(
          "The transpose_LayerNorm Plugin only has begin_norm_axis_ = 3"
          "but get %d.",
          begin_norm_axis_));
  const auto &input_dims = in[0].desc.dims;
  int statis_num = input_dims.d[0] * input_dims.d[2] * input_dims.d[3];
  mean_shape_[0] = statis_num;
  variance_shape_[0] = statis_num;
}

nvinfer1::DataType TransLayerNormPluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nb_inputs,
                    1,
                    platform::errors::InvalidArgument(
                        "The transpose_LayerNorm Plugin only has one input, "
                        "so the nb_inputs value should be 1, but get %d.",
                        nb_inputs));
  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
                     input_types[0] == nvinfer1::DataType::kHALF),
                    true,
                    platform::errors::InvalidArgument(
                        "The input type should be half or float"));
  return input_types[0];
}

int TransLayerNormPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc,
    const void *const *inputs,
    void *const *outputs,
    void *workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  const auto &input_dims = input_desc[0].dims;
  int begin_norm_axis = begin_norm_axis_;
  float eps = eps_;

  std::vector<int> input_shape;
  for (int i = 0; i < input_dims.nbDims; i++) {
    input_shape.push_back(input_dims.d[i]);
  }
  int input_numel = 1;
  for (int i = 0; i < input_dims.nbDims; i++) {
    input_numel *= input_dims.d[i];
  }
  PADDLE_ENFORCE_EQ(1,
                    mean_shape_.size(),
                    platform::errors::InvalidArgument(
                        "Size of mean_shape vector should be equal to 1,"
                        "but got size of mean_shape vector:%d",
                        mean_shape_.size()));
  PADDLE_ENFORCE_EQ(1,
                    variance_shape_.size(),
                    platform::errors::InvalidArgument(
                        "Size of variance_shape vector should be equal to 1,"
                        "but got size of variance_shape vector:%d",
                        variance_shape_.size()));
  PADDLE_ENFORCE_GE(mean_shape_[0],
                    0,
                    platform::errors::InvalidArgument(
                        "The size of mean vector should be positive,"
                        "but got:%d",
                        mean_shape_[0]));
  PADDLE_ENFORCE_GE(variance_shape_[0],
                    0,
                    platform::errors::InvalidArgument(
                        "The size of variance vector should be positive,"
                        "but got:%d",
                        variance_shape_[0]));
  // transpose does not change numel
  int trans_result_numel = input_numel;
  std::vector<int> trans_result_shape{
      input_shape[0], input_shape[2], input_shape[3], input_shape[1]};

  const auto input_ddim = phi::make_ddim(input_shape);
  int feature_size = static_cast<int>(input_ddim[1]);
  PADDLE_ENFORCE_EQ(feature_size,
                    scale_.size(),
                    platform::errors::InvalidArgument(
                        "scale's size should be equal to the feature_size,"
                        "but got feature_size:%d, scale's size:%d.",
                        feature_size,
                        scale_.size()));
  PADDLE_ENFORCE_EQ(feature_size,
                    bias_.size(),
                    platform::errors::InvalidArgument(
                        "bias's size should be equal to the feature_size,"
                        "but got feature_size:%d, bias's size:%d.",
                        feature_size,
                        bias_.size()));
  int device_id = -1;
  cudaGetDevice(&device_id);
  PADDLE_ENFORCE_GE(device_id,
                    0,
                    platform::errors::InvalidArgument(
                        "device_id should be positive, but got:%d",
                        device_id));
  mean_t.Resize(phi::make_ddim(mean_shape_));
  variance_t.Resize(phi::make_ddim(variance_shape_));
  float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
  float *variance_d =
      variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
  auto input_type = input_desc[0].type;

  paddle::platform::DeviceContextPool &pool =
      paddle::platform::DeviceContextPool::Instance();
  platform::CUDAPlace place(platform::GetCurrentDeviceId());
  auto *device_context = static_cast<phi::GPUContext *>(pool.Get(place));
  const phi::GPUContext &dev_ctx = *device_context;

  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. trans_layernorm-->fp32";
    VLOG(1) << "TRT Plugin format selected. trans_layernorm-->kLINEAR";
    const float *input = reinterpret_cast<const float *>(inputs[0]);
    float *layernorm_dst = static_cast<float *>(outputs[0]);
    float *dst = static_cast<float *>(outputs[1]);
    // transpose and norm do not change numel
    int trans_result_numel = input_numel;
    int norm_result_numel = input_numel;
    phi::DenseTensorMeta input_meta(phi::DataType::FLOAT32,
                                    phi::make_ddim(input_shape));
    phi::DenseTensorMeta bias_meta(phi::DataType::FLOAT32,
                                   phi::make_ddim({feature_size}));
    phi::DenseTensorMeta scale_meta(phi::DataType::FLOAT32,
                                    phi::make_ddim({feature_size}));
    phi::DenseTensorMeta trans_result_meta(phi::DataType::FLOAT32,
                                           phi::make_ddim(trans_result_shape));
    phi::DenseTensorMeta norm_result_meta(phi::DataType::FLOAT32,
                                          phi::make_ddim(trans_result_shape));
    std::shared_ptr<phi::Allocation> input_alloc(new phi::Allocation(
        static_cast<void *>(const_cast<float *>(input)),  // NOLINT
        input_numel * sizeof(float),
        place));
    std::shared_ptr<phi::Allocation> bias_alloc(
        new phi::Allocation(static_cast<float *>(bias_gpu_),  // NOLINT
                            feature_size * sizeof(float),
                            place));
    std::shared_ptr<phi::Allocation> scale_alloc(new phi::Allocation(
        static_cast<float *>(scale_gpu_), feature_size * sizeof(float), place));
    std::shared_ptr<phi::Allocation> trans_result_alloc(
        new phi::Allocation(static_cast<float *>(dst),  // NOLINT
                            trans_result_numel * sizeof(float),
                            place));
    std::shared_ptr<phi::Allocation> norm_result_alloc(
        new phi::Allocation(static_cast<float *>(layernorm_dst),  // NOLINT
                            norm_result_numel * sizeof(float),
                            place));
    const phi::DenseTensor input_tensor =
        phi::DenseTensor(input_alloc, input_meta);
    phi::DenseTensor bias_tensor = phi::DenseTensor(bias_alloc, bias_meta);
    phi::DenseTensor scale_tensor = phi::DenseTensor(scale_alloc, scale_meta);
    phi::DenseTensor trans_result_tensor =
        phi::DenseTensor(trans_result_alloc, trans_result_meta);
    phi::DenseTensor norm_result_tensor =
        phi::DenseTensor(norm_result_alloc, norm_result_meta);
    phi::TransposeKernel<float, phi::GPUContext>(dev_ctx,
                                                 input_tensor,
                                                 std::vector<int>{0, 2, 3, 1},
                                                 &trans_result_tensor);
    phi::LayerNormKernel<float, phi::GPUContext>(dev_ctx,
                                                 trans_result_tensor,
                                                 scale_tensor,
                                                 bias_tensor,
                                                 eps,
                                                 begin_norm_axis,
                                                 &norm_result_tensor,
                                                 &mean_t,
                                                 &variance_t);
  } else if (input_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. trans_layernorm-->fp16";
    const half *input = reinterpret_cast<const half *>(inputs[0]);
    half *layernorm_dst = static_cast<half *>(outputs[0]);
    half *dst = static_cast<half *>(outputs[1]);
    if (input_desc[0].format == nvinfer1::PluginFormat::kLINEAR) {
      VLOG(1) << "TRT Plugin format selected. trans_layernorm-->kLINEAR";
      phi::DenseTensorMeta input_meta(phi::DataType::FLOAT16,
                                      phi::make_ddim(input_shape));
      std::shared_ptr<phi::Allocation> input_alloc(new phi::Allocation(
          static_cast<void *>(const_cast<half *>(input)),  // NOLINT
          input_numel * sizeof(half),
          place));
      phi::DenseTensorMeta trans_result_meta(
          phi::DataType::FLOAT16, phi::make_ddim(trans_result_shape));
      std::shared_ptr<phi::Allocation> trans_result_alloc(
          new phi::Allocation(static_cast<void *>(dst),  // NOLINT
                              trans_result_numel * sizeof(half),
                              place));
      const phi::DenseTensor input_tensor =
          phi::DenseTensor(input_alloc, input_meta);
      phi::DenseTensor trans_result_tensor =
          phi::DenseTensor(trans_result_alloc, trans_result_meta);
      phi::TransposeKernel<phi::dtype::float16, phi::GPUContext>(
          dev_ctx,
          input_tensor,
          std::vector<int>{0, 2, 3, 1},
          &trans_result_tensor);
      phi::LayerNormDirectCUDAFunctor<half, float> layer_norm;
      layer_norm(stream,
                 dst,
                 trans_result_shape,
                 bias_gpu_,
                 scale_gpu_,
                 layernorm_dst,
                 mean_d,
                 variance_d,
                 begin_norm_axis,
                 eps);
    } else if (input_desc[0].format == nvinfer1::PluginFormat::kHWC8) {
      VLOG(1) << "TRT Plugin format selected. trans_layernorm-->kHWC8";
      int sm = getSMVersion();
      // sm >= 60 to support __ldg
      if (sm >= 60) {
        int hidden = input_shape[1];
        if (hidden % 2 == 0) {
          // batch * seq_length
          const size_t rows = static_cast<size_t>(
              input_shape[0] * input_shape[2] * input_shape[3]);
          int half_n = hidden / 2;
          int half_n_32 = (half_n + 31) / 32 * 32;
          dim3 block(std::min(half_n_32, 512));
          int rolls_per_thread = half_n / block.x;
          int unroll_factor = 8;
          while (unroll_factor > rolls_per_thread && unroll_factor > 1) {
            unroll_factor /= 2;
          }
          switch (unroll_factor) {
            case 1:
              HALF2_RESIDUAL_LAYERNORM_OPT2(1);
              break;
            case 2:
              HALF2_RESIDUAL_LAYERNORM_OPT2(2);
              break;
            case 4:
              HALF2_RESIDUAL_LAYERNORM_OPT2(4);
              break;
            case 8:
              HALF2_RESIDUAL_LAYERNORM_OPT2(8);
              break;
            default:
              PADDLE_THROW(platform::errors::Fatal(
                  "Invalid UNROLL_FACTOR in transpose_layernorm trt plugin."));
          }
        } else {
          cudaMemcpyAsync(
              dst, input, input_numel * sizeof(half), cudaMemcpyDeviceToDevice);
          phi::LayerNormDirectCUDAFunctor<half, float> layer_norm;
          layer_norm(stream,
                     input,
                     trans_result_shape,
                     bias_gpu_,
                     scale_gpu_,
                     layernorm_dst,
                     mean_d,
                     variance_d,
                     begin_norm_axis,
                     eps);
        }
      } else {
        cudaMemcpyAsync(
            dst, input, input_numel * sizeof(half), cudaMemcpyDeviceToDevice);
        phi::LayerNormDirectCUDAFunctor<half, float> layer_norm;
        layer_norm(stream,
                   input,
                   trans_result_shape,
                   bias_gpu_,
                   scale_gpu_,
                   layernorm_dst,
                   mean_d,
                   variance_d,
                   begin_norm_axis,
                   eps);
      }
    }
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("The TransLayerNormPluginDynamic TRT Plugin's "
                                "input type should be float or half."));
  }
  return cudaGetLastError() != cudaSuccess;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
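The fast kHWC8 path in GeneralResidualLayerNormOpt2 derives the row statistics from running sums: mean = sum(x)/len and inv_std = rsqrt(E[x^2] - mean^2 + eps), accumulated over half2 pairs. A numpy emulation of that arithmetic plus the launcher's unroll-factor selection (not from the commit; shapes are illustrative):

import numpy as np

rows, half_n = 4, 16                      # n in the kernel counts half2 pairs
eps = 1e-5
src = np.random.randn(rows, half_n * 2).astype(np.float32)

# Kernel statistics: one pass accumulating sum(x) and sum(x*x) per row.
x_sum = src.sum(axis=1)
x2_sum = (src * src).sum(axis=1)
mean = x_sum / (half_n * 2)                         # sums[0] / n / Half2VecSize
inv_std = 1.0 / np.sqrt(x2_sum / (half_n * 2) - mean * mean + eps)
normed = (src - mean[:, None]) * inv_std[:, None]

ref = (src - src.mean(1, keepdims=True)) / np.sqrt(src.var(1, keepdims=True) + eps)
assert np.allclose(normed, ref, atol=1e-4)

# Unroll-factor selection, mirroring the launcher in the .cu source above.
block_x = min((half_n + 31) // 32 * 32, 512)
unroll_factor = 8
while unroll_factor > half_n // block_x and unroll_factor > 1:
    unroll_factor //= 2
assert unroll_factor in (1, 2, 4, 8)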
paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.h (new file, mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class TransLayerNormPluginDynamic : public DynamicPluginTensorRT {
 public:
  TransLayerNormPluginDynamic(const float *bias,
                              const int bias_num,
                              const float *scale,
                              const int scale_num,
                              int begin_norm_axis,
                              float eps,
                              std::vector<int64_t> mean_shape,
                              std::vector<int64_t> variance_shape,
                              bool with_fp16)
      : begin_norm_axis_(begin_norm_axis),
        eps_(eps),
        mean_shape_(mean_shape),
        variance_shape_(variance_shape) {
    with_fp16_ = with_fp16;
    bias_.resize(bias_num);
    scale_.resize(scale_num);
    std::copy(bias, bias + bias_num, bias_.data());
    std::copy(scale, scale + scale_num, scale_.data());
  }

  TransLayerNormPluginDynamic(void const *serialData, size_t serialLength) {
    DeserializeValue(&serialData, &serialLength, &bias_);
    DeserializeValue(&serialData, &serialLength, &scale_);
    DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
    DeserializeValue(&serialData, &serialLength, &eps_);
    DeserializeValue(&serialData, &serialLength, &mean_shape_);
    DeserializeValue(&serialData, &serialLength, &variance_shape_);
    DeserializeValue(&serialData, &serialLength, &with_fp16_);
  }

  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override {
    auto ptr = new TransLayerNormPluginDynamic(bias_.data(),
                                               bias_.size(),
                                               scale_.data(),
                                               scale_.size(),
                                               begin_norm_axis_,
                                               eps_,
                                               mean_shape_,
                                               variance_shape_,
                                               with_fp16_);
    ptr->bias_gpu_ = bias_gpu_;
    ptr->scale_gpu_ = scale_gpu_;
    ptr->fp16_bias_gpu_ = fp16_bias_gpu_;
    ptr->fp16_scale_gpu_ = fp16_scale_gpu_;
    return ptr;
  }

  const char *getPluginType() const TRT_NOEXCEPT override {
    return "trans_layernorm_plugin_dynamic";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 2; }
  int initialize() TRT_NOEXCEPT override;
  void terminate() TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override {
    return SerializedSize(bias_) + SerializedSize(scale_) +
           SerializedSize(begin_norm_axis_) + SerializedSize(eps_) +
           SerializedSize(mean_shape_) + SerializedSize(variance_shape_) +
           SerializedSize(with_fp16_);
  }

  void serialize(void *buffer) const TRT_NOEXCEPT override {
    SerializeValue(&buffer, bias_);
    SerializeValue(&buffer, scale_);
    SerializeValue(&buffer, begin_norm_axis_);
    SerializeValue(&buffer, eps_);
    SerializeValue(&buffer, mean_shape_);
    SerializeValue(&buffer, variance_shape_);
    SerializeValue(&buffer, with_fp16_);
  }

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index,
      const nvinfer1::DimsExprs *inputs,
      int nb_inputs,
      nvinfer1::IExprBuilder &expr_builder)  // NOLINT
      TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc *inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) TRT_NOEXCEPT override;

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc,
              const void *const *inputs,
              void *const *outputs,
              void *workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const
      TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 private:
  std::vector<float> bias_;
  std::vector<float> scale_;
  phi::DenseTensor mean_t;
  phi::DenseTensor variance_t;
  int begin_norm_axis_;
  float eps_;
  std::vector<int64_t> mean_shape_;
  std::vector<int64_t> variance_shape_;
  // data on devices
  float *bias_gpu_{nullptr};
  float *scale_gpu_{nullptr};
  half *fp16_bias_gpu_{nullptr};
  half *fp16_scale_gpu_{nullptr};
};

class TransLayerNormPluginDynamicCreator : public TensorRTPluginCreator {
 public:
  const char *getPluginName() const TRT_NOEXCEPT override {
    return "trans_layernorm_plugin_dynamic";
  }
  const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
  nvinfer1::IPluginV2 *deserializePlugin(const char *name,
                                         const void *serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    return new TransLayerNormPluginDynamic(serial_data, serial_length);
  }
};
REGISTER_TRT_PLUGIN_V2(TransLayerNormPluginDynamicCreator);

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -31,6 +31,9 @@ if(WIN32)
   list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
        "test_element_groupnorm_act_fuse_pass")
   list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_groupnorm_act_pass_fuse_pass")
+  list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_trans_layernorm")
+  list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_trans_layernorm")
+  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_trans_layernorm")
   list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune")
   list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune")
 endif()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_trans_layernorm.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import List

import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest

import paddle.inference as paddle_infer


class TrtConvertTransLayernormTest(TrtLayerAutoScanTest):
    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        return True

    def sample_program_configs(self):
        def conv_filter_datagen(dics):
            c = dics["c"]
            x = (np.random.randn(c, c, 1, 1)) / np.sqrt(c)
            return x.astype(np.float32)

        def elementwise_bias_datagen(dics):
            c = dics["c"]
            x = np.random.random([c]) * 0.01
            return x.astype(np.float32)

        def layernorm_bias_datagen(dics):
            c = dics["c"]
            x = np.random.random([c]) * 0.1
            return x.astype(np.float32)

        def layernorm_scale_datagen(dics):
            x = np.ones([c])
            return x.astype(np.float32)

        def conv2d_input_datagen(dics):
            x = np.random.randn(dics["batch"], dics["c"], dics["h"], dics["w"])
            x = (x - np.mean(x)) / (np.std(x))
            return x.astype(np.float32)

        for batch in [2]:
            for begin_norm_axis in [2]:
                for h in [32, 64]:
                    for w in [32, 64]:
                        for c in [128, 320, 255, 133]:
                            for reshape in ["flatten", "reshape"]:
                                dics = {
                                    "batch": batch,
                                    "begin_norm_axis": begin_norm_axis,
                                    "h": h,
                                    "w": w,
                                    "c": c,
                                    "flatten": {
                                        "op_type": "flatten_contiguous_range",
                                        "op_inputs": {
                                            "X": ["transpose2_out"],
                                        },
                                        "op_outputs": {
                                            "Out": ["reshape_out"],
                                        },
                                        "op_attrs": {
                                            "start_axis": 1,
                                            "stop_axis": 2,
                                        },
                                    },
                                    "reshape": {
                                        "op_type": "reshape2",
                                        "op_inputs": {
                                            "X": ["transpose2_out"],
                                        },
                                        "op_outputs": {
                                            "Out": ["reshape_out"],
                                        },
                                        "op_attrs": {"shape": [-1, h * w, c]},
                                    },
                                }
                                ops_config = [
                                    {
                                        "op_type": "conv2d",
                                        "op_inputs": {
                                            "Input": ["conv2d_input"],
                                            "Filter": ["conv2d_filter"],
                                        },
                                        "op_outputs": {
                                            "Output": ["conv2d_output"],
                                        },
                                        "op_attrs": {
                                            "dilations": [1, 1],
                                            "padding_algorithm": "EXPLICIT",
                                            "groups": 1,
                                            "paddings": [0, 0],
                                            "strides": [1, 1],
                                            "data_format": "NCHW",
                                        },
                                    },
                                    {
                                        "op_type": "elementwise_add",
                                        "op_inputs": {
                                            "X": ["conv2d_output"],
                                            "Y": ["elementwise_bias"],
                                        },
                                        "op_outputs": {
                                            "Out": ["elementwise_out"]
                                        },
                                        "op_attrs": {"axis": 1},
                                    },
                                    {
                                        "op_type": "transpose2",
                                        "op_inputs": {
                                            "X": ["elementwise_out"],
                                        },
                                        "op_outputs": {
                                            "Out": ["transpose2_out"],
                                        },
                                        "op_attrs": {"axis": [0, 2, 3, 1]},
                                    },
                                    dics[reshape],
                                    {
                                        "op_type": "layer_norm",
                                        "op_inputs": {
                                            "X": ["reshape_out"],
                                            "Bias": ["layernorm_bias"],
                                            "Scale": ["layernorm_scale"],
                                        },
                                        "op_outputs": {
                                            "Y": ["layernorm_out"],
                                            "Mean": ["layernorm_mean"],
                                            "Variance": ["layernorm_variance"],
                                        },
                                        "op_attrs": {
                                            "epsilon": 1e-5,
                                            "begin_norm_axis": dics[
                                                "begin_norm_axis"
                                            ],
                                        },
                                    },
                                ]
                                ops = self.generate_op_config(ops_config)
                                program_config = ProgramConfig(
                                    ops=ops,
                                    weights={
                                        "conv2d_filter": TensorConfig(
                                            data_gen=partial(
                                                conv_filter_datagen, dics
                                            )
                                        ),
                                        "elementwise_bias": TensorConfig(
                                            data_gen=partial(
                                                elementwise_bias_datagen, dics
                                            )
                                        ),
                                        "layernorm_bias": TensorConfig(
                                            data_gen=partial(
                                                layernorm_bias_datagen, dics
                                            )
                                        ),
                                        "layernorm_scale": TensorConfig(
                                            data_gen=partial(
                                                layernorm_scale_datagen, dics
                                            )
                                        ),
                                    },
                                    inputs={
                                        "conv2d_input": TensorConfig(
                                            data_gen=partial(
                                                conv2d_input_datagen, dics
                                            )
                                        ),
                                    },
                                    outputs=["reshape_out", "layernorm_out"],
                                )
                                yield program_config

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        def generate_dynamic_shape(attrs, inputs):
            conv2d_c = inputs['conv2d_input'].shape[1]
            self.dynamic_shape.min_input_shape = {
                "conv2d_input": [1, conv2d_c, 32, 32],
                "conv2d_filter": [conv2d_c, conv2d_c, 1, 1],
                "elementwise_bias": [conv2d_c],
                "layernorm_bias": [conv2d_c],
                "layernorm_scale": [conv2d_c],
            }
            self.dynamic_shape.max_input_shape = {
                "conv2d_input": [4, conv2d_c, 64, 64],
                "conv2d_filter": [conv2d_c, conv2d_c, 1, 1],
                "elementwise_bias": [conv2d_c],
                "layernorm_bias": [conv2d_c],
                "layernorm_scale": [conv2d_c],
            }
            self.dynamic_shape.opt_input_shape = {
                "conv2d_input": [4, conv2d_c, 64, 64],
                "conv2d_filter": [conv2d_c, conv2d_c, 1, 1],
                "elementwise_bias": [conv2d_c],
                "layernorm_bias": [conv2d_c],
                "layernorm_scale": [conv2d_c],
            }

        def clear_dynamic_shape():
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            return 1, 3

        attrs = [
            program_config.ops[i].attrs for i in range(len(program_config.ops))
        ]
        inputs = program_config.inputs
        # only dynamic_shape is supported
        generate_dynamic_shape(attrs, inputs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True
        ), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True
        ), (
            1e-2,
            1e-2,
        )  # tol 1e-2 for half

    def add_skip_trt_case(self):
        pass

    def test(self):
        self.add_skip_trt_case()
        self.run_test()


if __name__ == "__main__":
    unittest.main()