Commit 22bfa579 (unverified)
Authored by Wangzheee on Dec 08, 2022; committed via GitHub on Dec 08, 2022.
[Paddle Inference] General optimization for no_varlen embedding layernorm (#48580)
* general optimization no_varlen embedding layernorm
Parent: 8c416653

Showing 20 changed files with 1,357 additions and 963 deletions (+1357 −963).
Changed files:

  paddle/fluid/framework/ir/CMakeLists.txt (+1 −1)
  paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc (+2 −2)
  paddle/fluid/inference/api/analysis_predictor.cc (+0 −3)
  paddle/fluid/inference/api/paddle_pass_builder.cc (+25 −31)
  paddle/fluid/inference/tensorrt/convert/CMakeLists.txt (+1 −1)
  paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc (+79 −118)
  paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc (+3 −3)
  paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt (+6 −5)
  paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu (+0 −291)
  paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h (+0 −446)
  paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_kernel.cu (+475 −0)
  paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernel_hface.cu (+0 −6)
  paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernel_mtron.cu (+0 −6)
  paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.cu (+497 −0)
  paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h (+203 −0)
  paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu (+54 −36)
  paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h (+0 −1)
  paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.cc (+3 −1)
  paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h (+4 −8)
  paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc (+4 −4)
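The changes below affect the ordinary (no_varlen) embedding + layernorm path, i.e. whenever TensorRT is enabled without the varseqlen/ERNIE special input setup. For context, a minimal Paddle Inference configuration that would exercise this path is sketched below; the model file names are placeholders, and the exact include path depends on how the inference library is installed.

#include <memory>
#include "paddle_inference_api.h"  // from the Paddle Inference distribution

int main() {
  paddle_infer::Config config;
  // Placeholder model files; substitute a real ERNIE/BERT-style model.
  config.SetModel("model.pdmodel", "model.pdiparams");
  config.EnableUseGpu(256 /*initial GPU memory (MB)*/, 0 /*device id*/);
  // TensorRT subgraph engine in fp16; no varseqlen options are set, so the
  // converter takes the no_varlen branch optimized by this commit.
  config.EnableTensorRtEngine(1 << 30 /*workspace*/, 1 /*max_batch*/,
                              5 /*min_subgraph_size*/,
                              paddle_infer::PrecisionType::kHalf,
                              false /*use_static*/, false /*use_calib*/);
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor ? 0 : 1;
}

Because no pos_id/mask_id is registered through the varseqlen options, flag_varseqlen stays false in the converter and the new ManyEmbLayerNormPluginDynamic plugin is used.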
paddle/fluid/framework/ir/CMakeLists.txt

@@ -140,7 +140,7 @@ if(WITH_TENSORRT)
   pass_library(preln_layernorm_x_fuse_pass inference)
 endif()
-if(WITH_TENSORRT AND NOT WIN32)
+if(WITH_TENSORRT)
   pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
 endif()
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc

@@ -1170,14 +1170,14 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
               "preln_embedding_eltwise_layernorm_fuse_"
               "pass. please use no_varseqlen"));
     }
-  } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
+  } else if (!use_varseqlen && pos_id == "") {
     VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
   } else {
     PADDLE_THROW(
         platform::errors::Fatal("Use transformer'varseqlen need config: "
                                 "use_varseqlen, set pos_id, set "
                                 "mask_id. Or not use varseqlen, do not set "
-                                "pos_id, set mask_id. Please "
+                                "pos_id. Please "
                                 "reconfig"));
   }
   graph->Set(kMultiheadMatmulPass, new bool(true));
paddle/fluid/inference/api/analysis_predictor.cc

@@ -2338,11 +2338,8 @@ USE_TRT_CONVERTER(conv3d_transpose);
 USE_TRT_CONVERTER(mish);
 USE_TRT_CONVERTER(deformable_conv);
 USE_TRT_CONVERTER(pool3d)
-#ifdef _WIN32
-#else
 USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
 USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
-#endif
 USE_TRT_CONVERTER(preln_skip_layernorm)
 USE_TRT_CONVERTER(preln_residual_bias)
 USE_TRT_CONVERTER(c_allreduce_sum)
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -95,39 +95,33 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "identity_scale_op_clean_pass",  //
       "add_support_int8_pass",         //
       // "fc_fuse_pass",               //
       "simplify_with_basic_ops_pass",  //
-#if defined _WIN32
-#else
       "trt_embedding_eltwise_layernorm_fuse_pass",    //
       "preln_embedding_eltwise_layernorm_fuse_pass",  //
-#endif
       "delete_c_identity_op_pass",            //
       "trt_multihead_matmul_fuse_pass_v2",    //
       "trt_multihead_matmul_fuse_pass_v3",    //
       "multihead_matmul_roformer_fuse_pass",  //
       "constant_folding_pass",                //
       "vit_attention_fuse_pass",              //
       "trt_skip_layernorm_fuse_pass",         //
       "preln_skip_layernorm_fuse_pass",       //
       "layernorm_shift_partition_fuse_pass",  //
       "merge_layernorm_fuse_pass",            //
       "preln_residual_bias_fuse_pass",        //
       "preln_layernorm_x_fuse_pass",          //
       "reverse_roll_fuse_pass",               //
-      // "set_transformer_input_convert_pass",  //
       "conv_bn_fuse_pass",                    //
       "unsqueeze2_eltwise_fuse_pass",         //
       "trt_squeeze2_matmul_fuse_pass",        //
       "trt_flatten2_matmul_fuse_pass",        //
       "trt_map_matmul_v2_to_mul_pass",        //
       "trt_map_matmul_v2_to_matmul_pass",     //
       "trt_map_matmul_to_mul_pass",           //
       "fc_fuse_pass",                         //
       "conv_elementwise_add_fuse_pass",       //
       "remove_padding_recover_padding_pass",  //
       "delete_remove_padding_recover_padding_pass",  //
       // "yolo_box_fuse_pass",      //
       "dense_fc_to_sparse_pass",                //
       "dense_multihead_matmul_to_sparse_pass",  //
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

@@ -94,7 +94,7 @@ list(
     fused_lookup_tables_op.cc
     expand_v2_op.cc)

-if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
+if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
   list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc
        preln_emb_eltwise_layernorm.cc)
 endif()
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc

@@ -13,7 +13,7 @@ limitations under the License. */
 #include "paddle/fluid/inference/tensorrt/convert/utils.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
-#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h"
 #include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
 #include "paddle/phi/core/ddim.h"
@@ -36,7 +36,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
                   const framework::Scope& scope,
                   bool test_mode) override {
     VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer";
     // get the presistable var's data
     auto GetWeight = [&](const std::string& var_name,
                          framework::DDim* dim) -> TensorRTEngine::Weight {
@@ -47,32 +46,13 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       return weight;
     };
-    auto GetFp16Weight = [&](const std::string& var_name,
-                             framework::DDim* dim) -> TensorRTEngine::Weight {
-      auto* temp_var = scope.FindVar(var_name);
-      auto* temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
-      *dim = temp_tensor->dims();
-      auto weight = engine_->GetFp16TrtWeight(var_name, *temp_tensor);
-      return weight;
-    };
-    auto GetFp32Weight = [&](const std::string& var_name,
-                             framework::DDim* dim) -> TensorRTEngine::Weight {
-      auto* temp_var = scope.FindVar(var_name);
-      auto* temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
-      *dim = temp_tensor->dims();
-      auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor);
-      return weight;
-    };
     framework::OpDesc op_desc(op, nullptr);
     auto pos_id_name = engine_->tensorrt_transformer_posid();
     auto mask_id_name = engine_->tensorrt_transformer_maskid();
     bool flag_varseqlen =
         engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != "";
-    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
-    int hidden = 0;
-    // Declare inputs
+    // bool with_fp16 = engine_->WithFp16() &&
+    // !engine_->disable_trt_plugin_fp16(); int hidden = 0; Declare inputs
     std::vector<nvinfer1::ITensor*> input_ids;
     // Declare inputs_weight
@@ -95,55 +75,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
     if (flag_varseqlen) {
       engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
       engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
-      auto mask_id_tensor = engine_->GetITensor("mask_id");
-      auto mask_dims = mask_id_tensor->getDimensions();
-      auto slice_start_dims = mask_dims;
-      auto slice_stride_dims = mask_dims;
-      for (int i = 0; i < mask_dims.nbDims; i++) {
-        slice_start_dims.d[i] = 0;
-        slice_stride_dims.d[i] = 1;
-      }
-      auto* shape_tensor = Shape(mask_id_tensor);
-      std::vector<nvinfer1::ITensor*> size_vec_tensor;
-      std::vector<nvinfer1::ITensor*> start_vec_tensor;
-      for (int i = 0; i < mask_dims.nbDims; i++) {
-        size_vec_tensor.push_back(Add1DConstantLayer(1));
-        start_vec_tensor.push_back(Add1DConstantLayer(0));
-      }
-      size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1);
-      auto size_tensor = Concat(size_vec_tensor);
-      auto start_tensor = Concat(start_vec_tensor);
-      auto slice_layer =
-          TRT_ENGINE_ADD_LAYER(engine_,
-                               Slice,
-                               *mask_id_tensor,
-                               slice_start_dims,
-                               slice_start_dims,
-                               slice_stride_dims);  // unuseful slice_start_dims
-      slice_layer->setInput(1, *start_tensor);
-      slice_layer->setInput(2, *size_tensor);
-      slice_layer->setName(
-          ("Embeltwise_slice_layer (Output: slice_max_seqlen " +
-           op_desc.Output("Out")[0] + ")")
-              .c_str());
-      engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f);
-      auto* reshape_layer =
-          TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *slice_layer->getOutput(0));
-      nvinfer1::Dims shape_dim;
-      shape_dim.nbDims = 1;
-      shape_dim.d[0] = -1;
-      reshape_layer->setReshapeDimensions(shape_dim);
-      reshape_layer->setName(
-          ("Embeltwise_reshape_layer (Output: max_seqlen " +
-           op_desc.Output("Out")[0] + ")")
-              .c_str());
-      engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f);
-      engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0));
       for (int i = 0; i < input_num; i++) {
         auto input_tensor = engine_->GetITensor(id_names[i]);
         weight = GetWeight(emb_names[i], &emb_dims);
@@ -156,7 +87,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
         input_embs.push_back(weight.get());
         emb_sizes.push_back(weight.get().count);
       }
-      hidden = emb_dims[1];
     }
     bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims);
     scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims);
@@ -206,26 +136,29 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       plugin_ptr->fields = fields.data();
       std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids;
-      plugin_inputs.emplace_back(engine_->GetITensor(
-          "max_seqlen_tensor"));  // max_seqlen, eval_placeholder_3
+      plugin_inputs.emplace_back(
+          engine_->GetITensor("mask_id"));  // input mask_id
       auto creator = GetPluginRegistry()->getPluginCreator(
-          "ManyEmbLayerNormPluginDynamic", "1");
-      auto plugin_obj =
-          creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
+          "ManyEmbLayerNormVarlenPluginDynamic", "1");
+      auto plugin_obj = creator->createPlugin(
+          "ManyEmbLayerNormVarlenPluginDynamic", plugin_ptr);
       auto plugin_layer = engine_->network()->addPluginV2(
           plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
-      plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V1(Output: " +
+      plugin_layer->setName(("ManyEmbLayerNormVarlenPluginDynamicV1(Output: " +
                              op_desc.Output("Out")[0] + ")")
                                 .c_str());
       free(plugin_ptr);
       if (enable_int8) {
         float out_scale =
             PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold"));
-        engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), out_scale);
-        engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), out_scale);
+        engine_->SetTensorDynamicRange(plugin_layer->getOutput(0),
+                                       out_scale);  // output
+        engine_->SetTensorDynamicRange(plugin_layer->getOutput(1),
+                                       out_scale);  // mask
+        engine_->SetTensorDynamicRange(plugin_layer->getOutput(2),
+                                       out_scale);  // max seqlen
       }
       if (engine_->with_interleaved()) {
         VLOG(4) << "fused emb_eltwise_layernorm op: use_varseqlen and "
@@ -249,54 +182,82 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       auto output_name = op_desc.Output("Out")[0];
       RreplenishLayerAndOutput(layer,
                                "ManyEmbLayerNormPluginDynamic_V1",
-                               {output_name, std::string("qkv_plugin_mask")},
+                               {output_name,
+                                std::string("qkv_plugin_mask"),
+                                std::string("max_seqlen_tensor")},
                                test_mode);
     }
   } else {
     for (int i = 0; i < input_num; i++) {
-      if (with_fp16) {
-        weight = GetFp16Weight(emb_names[i], &emb_dims);
-      } else {
-        weight = GetFp32Weight(emb_names[i], &emb_dims);
-      }
-      input_ids.push_back(engine_->GetITensor(id_names[i]));
+      auto input_tensor = engine_->GetITensor(id_names[i]);
+      weight = GetWeight(emb_names[i], &emb_dims);
+      input_ids.push_back(input_tensor);
       input_embs.push_back(weight.get());
       emb_sizes.push_back(weight.get().count);
-      hidden = emb_dims[1];
     }
-    if (with_fp16) {
-      bias_weight = GetFp16Weight(op_desc.Input("Bias").front(), &bias_dims);
-      scale_weight =
-          GetFp16Weight(op_desc.Input("Scale").front(), &scale_dims);
-    } else {
-      bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims);
-      scale_weight =
-          GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims);
-      // hidden = emb_dims[1];
-    }
+    bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims);
+    scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims);
     bias_size = phi::product(bias_dims);
     scale_size = phi::product(scale_dims);
-    float eps = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
-    plugin::DynamicPluginTensorRT* plugin = nullptr;
-    std::vector<void*> input_embs_data;
-    for (size_t i = 0; i < input_embs.size(); ++i) {
-      input_embs_data.push_back(const_cast<void*>(
-          reinterpret_cast<const void*>(input_embs[i].values)));
-    }
+    int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
+    if (enable_int8) {
+      output_fp16 = 1;
+    }
+    std::vector<nvinfer1::PluginField> fields;
+    std::vector<std::string> temp_fields_keys;
+    fields.emplace_back("bert_embeddings_layernorm_beta",
+                        bias_weight.get().values,
+                        GetPluginFieldType(bias_weight.get().type),
+                        static_cast<int32_t>(bias_size));
+    fields.emplace_back("bert_embeddings_layernorm_gamma",
+                        scale_weight.get().values,
+                        GetPluginFieldType(scale_weight.get().type),
+                        static_cast<int32_t>(scale_size));
+    fields.emplace_back(
+        "output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1);
+    for (int i = 0; i < input_num; ++i) {
+      temp_fields_keys.push_back("bert_embeddings_word_embeddings_" +
+                                 std::to_string(i));
+      fields.emplace_back(temp_fields_keys.rbegin()->c_str(),
+                          input_embs[i].values,
+                          GetPluginFieldType(input_embs[i].type),
+                          static_cast<int32_t>(emb_sizes[i]));
+    }
+    nvinfer1::PluginFieldCollection* plugin_ptr =
+        static_cast<nvinfer1::PluginFieldCollection*>(
+            malloc(sizeof(*plugin_ptr) +
+                   fields.size() * sizeof(nvinfer1::PluginField)));
+    plugin_ptr->nbFields = static_cast<int>(fields.size());
+    plugin_ptr->fields = fields.data();
+    std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids;
+    auto creator = GetPluginRegistry()->getPluginCreator(
+        "ManyEmbLayerNormPluginDynamic", "1");
+    auto plugin_obj =
+        creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
+    auto plugin_layer = engine_->network()->addPluginV2(
+        plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
+    plugin_layer->setName(("ManyEmbLayerNormPluginDynamicV1(Output: " +
+                           op_desc.Output("Out")[0] + ")")
+                              .c_str());
+    free(plugin_ptr);
+    if (enable_int8) {
+      float out_scale =
+          PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+      engine_->SetTensorDynamicRange(plugin_layer->getOutput(0),
+                                     out_scale);  // output
+    }
-    plugin = new plugin::EmbEltwiseLayernormPluginDynamic(
-        input_embs_data,
-        const_cast<void*>(static_cast<const void*>(bias_weight.get().values)),
-        const_cast<void*>(static_cast<const void*>(scale_weight.get().values)),
-        emb_sizes,
-        bias_size,
-        scale_size,
-        hidden,
-        eps,
-        with_fp16);
-    layer = engine_->AddDynamicPlugin(input_ids.data(), input_num, plugin);
+    layer = plugin_layer;
     auto output_name = op_desc.Output("Out")[0];
-    RreplenishLayerAndOutput(
-        layer, "emb_eltwise_layernorm", {output_name}, test_mode);
+    RreplenishLayerAndOutput(
+        layer, "ManyEmbLayerNormPluginDynamicV1", {output_name}, test_mode);
   }
 }
};
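A detail worth calling out in the converter diff above: both branches hand the plugin its attributes through a nvinfer1::PluginFieldCollection that is malloc'ed, pointed at a std::vector<nvinfer1::PluginField>, and freed right after createPlugin() returns. A minimal sketch of that pattern, assuming the fields vector outlives the createPlugin() call (the helper name MakeFieldCollection is ours, not Paddle's):

#include <cstdlib>
#include <vector>
#include "NvInfer.h"

nvinfer1::PluginFieldCollection* MakeFieldCollection(
    std::vector<nvinfer1::PluginField>& fields) {
  // Allocate the collection header (the converter also reserves room for the
  // fields themselves, but fields.data() is what actually gets referenced).
  auto* fc = static_cast<nvinfer1::PluginFieldCollection*>(
      std::malloc(sizeof(nvinfer1::PluginFieldCollection) +
                  fields.size() * sizeof(nvinfer1::PluginField)));
  fc->nbFields = static_cast<int32_t>(fields.size());
  fc->fields = fields.data();  // points into the vector, not the malloc block
  return fc;  // caller std::free()s this once the plugin has been created
}

Since fc->fields aliases the vector's storage, free(plugin_ptr) in the converter releases only the header; the field data stays alive in the enclosing scope until the plugin has copied what it needs.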
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc

@@ -194,10 +194,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
                                     "max_seqlen_tensor"));  // max_seqlen, eval_placeholder_3
       auto creator = GetPluginRegistry()->getPluginCreator(
-          "ManyEmbLayerNormPluginDynamic", "2");
-      auto plugin_obj =
-          creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
+          "ManyEmbLayerNormVarlenPluginDynamic", "2");
+      auto plugin_obj = creator->createPlugin(
+          "ManyEmbLayerNormVarlenPluginDynamic", plugin_ptr);
       auto plugin_layer = engine_->network()->addPluginV2(
           plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
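Both converters resolve the plugin through TensorRT's global plugin registry rather than constructing it directly, so the name/version strings passed to getPluginCreator must match exactly what the plugin's creator reports (and what REGISTER_TRT_PLUGIN_V2 registered) — here "ManyEmbLayerNormVarlenPluginDynamic" version "2" for the PrelnEmb path. A hedged sketch of that lookup, with an illustrative helper name:

#include "NvInfer.h"

nvinfer1::IPluginV2* CreateByName(char const* name,
                                  char const* version,
                                  nvinfer1::PluginFieldCollection const* fc) {
  // getPluginRegistry() is TensorRT's process-wide registry accessor.
  auto* creator = getPluginRegistry()->getPluginCreator(name, version);
  if (creator == nullptr) {
    return nullptr;  // nothing registered under this name/version pair
  }
  return creator->createPlugin(name, fc);
}

A renamed creator with an unchanged version string (as in this commit) therefore only works because every lookup site was updated in the same change.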
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt

@@ -11,7 +11,6 @@ list(
     group_norm_op_plugin.cu
     layer_norm_op_plugin.cu
     instance_norm_op_plugin.cu
-    emb_eltwise_layernorm_plugin.cu
     qkv_to_context_plugin.cu
     skip_layernorm_op_plugin.cu
     hard_swish_op_plugin.cu
@@ -38,12 +37,14 @@ list(
     merge_layernorm_op_plugin.cu
     skip_merge_layernorm_op_plugin.cu
     generic_plugin.cu
-    lookup_table.cu)
+    lookup_table.cu
+    many_emb_layernorm_plugin.cu
+    many_emb_Layernorm_kernel.cu)

-if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
+if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
   list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
-       many_emb_Layernorm_varseqlen_kernelMTron.cu
-       many_emb_Layernorm_varseqlen_kernelHFace.cu)
+       many_emb_Layernorm_varseqlen_kernel_mtron.cu
+       many_emb_Layernorm_varseqlen_kernel_hface.cu)
 endif()
 if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu (deleted, 100644 → 0)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh>  // NOLINT
#include <type_traits>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic shape plugin requires TRT version greater than 6.0.
#if IS_TRT_VERSION_GE(6000)

template <typename T>
void EmbEltwiseLayernormPluginDynamicImpl<T>::shareGPUData(
    const EmbEltwiseLayernormPluginDynamicImplBase* anthor) {
  auto* ptr =
      dynamic_cast<const EmbEltwiseLayernormPluginDynamicImpl<T>*>(anthor);
  if (!ptr->is_initialized_) {
    return;
  }
  embs_gpu_ = ptr->embs_gpu_;
  scale_gpu_ = ptr->scale_gpu_;
  bias_gpu_ = ptr->bias_gpu_;
  int input_num = embs_.size();
  in_ptr_tensor_.Resize({input_num});
  emb_ptr_tensor_.ShareDataWith(ptr->emb_ptr_tensor_);
}

template <typename T>
int EmbEltwiseLayernormPluginDynamicImpl<T>::initialize() {
  if (is_initialized_) {
    return 0;
  }
  embs_gpu_.resize(embs_.size());
  for (int i = 0; i < embs_.size(); i++) {
    if (embs_[i]) {
      T* host_ptr = embs_[i];
      auto size = emb_sizes_[i];
      cudaMalloc(&embs_gpu_[i], sizeof(T) * size);
      cudaMemcpy(
          embs_gpu_[i], host_ptr, size * sizeof(T), cudaMemcpyHostToDevice);
    }
  }
  if (bias_) {
    cudaMalloc(&bias_gpu_, sizeof(T) * bias_size_);
    cudaMemcpy(
        bias_gpu_, bias_, bias_size_ * sizeof(T), cudaMemcpyHostToDevice);
  }
  if (scale_) {
    cudaMalloc(&scale_gpu_, sizeof(T) * scale_size_);
    cudaMemcpy(
        scale_gpu_, scale_, scale_size_ * sizeof(T), cudaMemcpyHostToDevice);
  }
  int input_num = embs_.size();
  in_ptr_tensor_.Resize({input_num});
  emb_ptr_tensor_.Resize({input_num});
  cudaGetDevice(&device_id_);
  auto emb_ptr_gpu_d =
      emb_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
  cudaMemcpy(emb_ptr_gpu_d,
             embs_gpu_.data(),
             sizeof(uintptr_t) * input_num,
             cudaMemcpyHostToDevice);
  is_initialized_ = true;
  return 0;
}

template <typename T>
void EmbEltwiseLayernormPluginDynamicImpl<T>::terminate() {
  for (int i = 0; i < embs_gpu_.size(); ++i) {
    if (embs_gpu_[i]) {
      cudaFree(embs_gpu_[i]);
      embs_gpu_[i] = nullptr;
    }
  }
  if (bias_gpu_) {
    cudaFree(bias_gpu_);
    bias_gpu_ = nullptr;
  }
  if (scale_gpu_) {
    cudaFree(scale_gpu_);
    scale_gpu_ = nullptr;
  }
}

template <typename T>
int EmbEltwiseLayernormPluginDynamicImpl<T>::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc,
    const void* const* inputs,
    void* const* outputs,
    void* workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  auto id_dims = input_desc[0].dims;
  int batch = id_dims.d[0];
  int seq_len = id_dims.d[1];
  int input_num = embs_.size();
  cudaGetDevice(&device_id_);
  auto in_ptr_gpu_d =
      in_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
  auto emb_ptr_gpu_d =
      emb_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
  cudaMemcpyAsync(in_ptr_gpu_d,
                  reinterpret_cast<const void*>(inputs),
                  sizeof(uintptr_t) * input_num,
                  cudaMemcpyHostToDevice,
                  stream);
  auto out_type = output_desc[0].type;
  if (std::is_same<T, float>::value) {
    PADDLE_ENFORCE_EQ(
        out_type == nvinfer1::DataType::kFLOAT,
        true,
        platform::errors::InvalidArgument(
            "The EmbEltwiseLayernorm Plugin only support fp32 input."));
  } else if (std::is_same<T, half>::value) {
    PADDLE_ENFORCE_EQ(
        out_type == nvinfer1::DataType::kHALF,
        true,
        platform::errors::InvalidArgument(
            "The EmbEltwiseLayernorm Plugin only support fp16 input."));
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "Unsupport data type, the out type of EmbEltwiseLayernorm should be "
        "float or half."));
  }
  auto* output_d = reinterpret_cast<T*>(outputs[0]);
  operators::math::EmbEltwiseLayerNormFunctor<T> emb_eltwise_layernorm_func;
  emb_eltwise_layernorm_func(batch,
                             seq_len,
                             hidden_size_,
                             in_ptr_gpu_d,
                             scale_gpu_,
                             bias_gpu_,
                             emb_ptr_gpu_d,
                             output_d,
                             eps_,
                             input_num,
                             stream);
  return cudaGetLastError() != cudaSuccess;
}

template class EmbEltwiseLayernormPluginDynamicImpl<float>;
#ifdef TRT_PLUGIN_FP16_AVALIABLE
template class EmbEltwiseLayernormPluginDynamicImpl<half>;
#endif

int EmbEltwiseLayernormPluginDynamic::initialize() TRT_NOEXCEPT {
  impl_->initialize();
  return 0;
}

void EmbEltwiseLayernormPluginDynamic::terminate() TRT_NOEXCEPT {
  impl_->terminate();
}

nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs* inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {  // NOLINT
  PADDLE_ENFORCE_EQ(
      output_index,
      0,
      platform::errors::InvalidArgument(
          "There is only one output of the EmbEltwiseLayernorm, "
          "so the index should be zero,"
          "but it's (%d)",
          output_index));
  nvinfer1::DimsExprs ret;
  ret.nbDims = 3;
  ret.d[0] = inputs[0].d[0];
  ret.d[1] = inputs[0].d[1];
  ret.d[2] = expr_builder.constant(hidden_size_);
  return ret;
}

bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc* in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of swish plugin shoule not be nullptr."));
  PADDLE_ENFORCE_EQ(nb_outputs,
                    1,
                    platform::errors::InvalidArgument(
                        "The EmbEltwiseLayerNorm's output should be one"
                        "but it's (%d) outputs.",
                        nb_outputs));
  int all_nums = nb_inputs + nb_outputs;
  PADDLE_ENFORCE_LT(pos,
                    all_nums,
                    platform::errors::InvalidArgument(
                        "The pos(%d) should be less than the "
                        "num(%d) of the input and the output.",
                        pos,
                        all_nums));
  const nvinfer1::PluginTensorDesc& desc = in_out[pos];
  if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
    return false;
  }
  if (pos == 0) {
    return desc.type == nvinfer1::DataType::kINT32;
  }
  const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
  if (pos < all_nums - 1) {
    return desc.type == nvinfer1::DataType::kINT32 &&
           desc.dims.d[0] == prev.dims.d[0] &&
           desc.dims.d[1] == prev.dims.d[1];
  }
  // output
  if (pos == all_nums - 1) {
    if (with_fp16_ == false) {
      return desc.type == nvinfer1::DataType::kFLOAT;
    } else {
      return desc.type == nvinfer1::DataType::kHALF;
    }
  }
  return false;
}

nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(
      index,
      0,
      platform::errors::InvalidArgument(
          "The EmbEltwiseLayernorm Plugin only has one output, so the "
          "index value should be 0, but get %d.",
          index));
  if (with_fp16_)
    return nvinfer1::DataType::kHALF;
  else
    return nvinfer1::DataType::kFLOAT;
}

int EmbEltwiseLayernormPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc,
    const void* const* inputs,
    void* const* outputs,
    void* workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  impl_->enqueue(input_desc, output_desc, inputs, outputs, workspace, stream);
  return cudaGetLastError() != cudaSuccess;
}

#endif
}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h (deleted, 100644 → 0)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
class EmbEltwiseLayernormPluginDynamicImplBase {
 public:
  EmbEltwiseLayernormPluginDynamicImplBase() {}
  virtual ~EmbEltwiseLayernormPluginDynamicImplBase() {}

  virtual int initialize() = 0;
  virtual void terminate() = 0;
  virtual int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                      const nvinfer1::PluginTensorDesc* outputDesc,
                      const void* const* inputs,
                      void* const* outputs,
                      void* workspace,
                      cudaStream_t stream) = 0;
  virtual void shareGPUData(
      const EmbEltwiseLayernormPluginDynamicImplBase* anthor) = 0;
};

template <typename T>
class EmbEltwiseLayernormPluginDynamicImpl
    : public EmbEltwiseLayernormPluginDynamicImplBase {
 public:
  explicit EmbEltwiseLayernormPluginDynamicImpl(std::vector<T*> input_embs,
                                                T* bias,
                                                T* scale,
                                                std::vector<int> emb_sizes,
                                                int bias_size,
                                                int scale_size,
                                                int hidden_size,
                                                float eps)
      : embs_(input_embs),
        bias_(bias),
        scale_(scale),
        emb_sizes_(emb_sizes),
        bias_size_(bias_size),
        scale_size_(scale_size),
        hidden_size_(hidden_size),
        eps_(eps) {}

  ~EmbEltwiseLayernormPluginDynamicImpl() {}

  int initialize();
  void terminate();
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs,
              void* const* outputs,
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT;
  void shareGPUData(const EmbEltwiseLayernormPluginDynamicImplBase* anthor);

 private:
  std::vector<T*> embs_;
  T* bias_{nullptr};
  T* scale_{nullptr};
  // data on devices
  T* bias_gpu_{nullptr};
  T* scale_gpu_{nullptr};
  std::vector<T*> embs_gpu_;
  std::vector<int> emb_sizes_;
  int bias_size_;
  int scale_size_;
  int hidden_size_;
  float eps_;
  phi::DenseTensor in_ptr_tensor_, emb_ptr_tensor_;
  int device_id_{0};
  bool is_initialized_{false};
};

class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
 public:
  explicit EmbEltwiseLayernormPluginDynamic(std::vector<void*> input_embs,
                                            void* bias,
                                            void* scale,
                                            std::vector<int> emb_sizes,
                                            int bias_size,
                                            int scale_size,
                                            int hidden_size,
                                            float eps,
                                            bool with_fp16)
      : embs_(input_embs),
        bias_(bias),
        scale_(scale),
        emb_sizes_(emb_sizes),
        bias_size_(bias_size),
        scale_size_(scale_size),
        hidden_size_(hidden_size),
        eps_(eps),
        own_host_buff_(false) {
    with_fp16_ = with_fp16;
    if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
      VLOG(1) << "TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp16";
      instantiateImpl<half>();
#else
      PADDLE_THROW(platform::errors::Fatal(
          "The Ernie(Bert) tensorRT plugin should be "
          "complied with CUDA version >= 10.0 when running with fp16. "
          "Please recomplie it or try to use fp32 by set "
          "config.EnableTensorRtEngine(1 << 30, 1, 5, "
          "AnalysisConfig::Precision::kFloat32, false, false) "));
#endif
    } else {
      VLOG(1) << "TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp32";
      instantiateImpl<float>();
    }
  }

  EmbEltwiseLayernormPluginDynamic(void const* serial_data,
                                   size_t serial_length)
      : own_host_buff_(true) {
    // the first var is with_fp16, we will use it.
    DeserializeValue(&serial_data, &serial_length, &with_fp16_);
    DeserializeValue(&serial_data, &serial_length, &emb_sizes_);
    DeserializeValue(&serial_data, &serial_length, &bias_size_);
    DeserializeValue(&serial_data, &serial_length, &scale_size_);

    embs_.resize(emb_sizes_.size());
    if (with_fp16_) {
      for (size_t i = 0; i < emb_sizes_.size(); i++) {
        auto size = emb_sizes_[i];
        auto ptr = new half[size];
        memcpy(ptr, serial_data, sizeof(half) * size);
        embs_[i] = ptr;
        reinterpret_cast<char const*&>(serial_data) += size * sizeof(half);
        serial_length -= size * sizeof(half);
      }
      if (bias_size_) {
        bias_ = new half[bias_size_];
        memcpy(bias_, serial_data, sizeof(half) * bias_size_);
      }
      reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(half);
      serial_length -= bias_size_ * sizeof(half);
      if (scale_size_) {
        scale_ = new half[scale_size_];
        memcpy(scale_, serial_data, sizeof(half) * scale_size_);
      }
      reinterpret_cast<char const*&>(serial_data) +=
          scale_size_ * sizeof(half);
      serial_length -= scale_size_ * sizeof(half);
    } else {
      for (size_t i = 0; i < emb_sizes_.size(); i++) {
        auto size = emb_sizes_[i];
        auto ptr = new float[size];
        memcpy(ptr, serial_data, sizeof(float) * size);
        embs_[i] = ptr;
        reinterpret_cast<char const*&>(serial_data) += size * sizeof(float);
        serial_length -= size * sizeof(float);
      }
      if (bias_size_) {
        bias_ = new float[bias_size_];
        memcpy(bias_, serial_data, sizeof(float) * bias_size_);
      }
      reinterpret_cast<char const*&>(serial_data) +=
          bias_size_ * sizeof(float);
      serial_length -= bias_size_ * sizeof(float);
      if (scale_size_) {
        scale_ = new float[scale_size_];
        memcpy(scale_, serial_data, sizeof(float) * scale_size_);
      }
      reinterpret_cast<char const*&>(serial_data) +=
          scale_size_ * sizeof(float);
      serial_length -= scale_size_ * sizeof(float);
    }
    DeserializeValue(&serial_data, &serial_length, &hidden_size_);
    DeserializeValue(&serial_data, &serial_length, &eps_);
    if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
      instantiateImpl<half>();
#else
      PADDLE_THROW(platform::errors::Fatal(
          "The Ernie(Bert) tensorRT plugin should be "
          "complied with CUDA version >= 10.0 when running with fp16. "
          "Please recomplie it or try to use fp32 by set "
          "config.EnableTensorRtEngine(1 << 30, 1, 5, "
          "AnalysisConfig::Precision::kFloat32, false, false) "));
#endif
    } else {
      instantiateImpl<float>();
    }
  }

  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    auto ptr = new EmbEltwiseLayernormPluginDynamic(embs_,
                                                    bias_,
                                                    scale_,
                                                    emb_sizes_,
                                                    bias_size_,
                                                    scale_size_,
                                                    hidden_size_,
                                                    eps_,
                                                    with_fp16_);
    ptr->shareGPUData(this);
    return ptr;
  }

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "fused_embedding_eltwise_layernorm_plugin";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  int initialize() TRT_NOEXCEPT override;
  void terminate() TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override {
    int sum_num = 0;
    sum_num += SerializedSize(with_fp16_);
    sum_num += SerializedSize(emb_sizes_);
    if (with_fp16_) {
      for (size_t i = 0; i < emb_sizes_.size(); i++) {
        sum_num += emb_sizes_[i] * sizeof(half);
      }
      sum_num += (bias_size_ + scale_size_) * sizeof(half);
    } else {
      for (size_t i = 0; i < emb_sizes_.size(); i++) {
        sum_num += emb_sizes_[i] * sizeof(float);
      }
      sum_num += (bias_size_ + scale_size_) * sizeof(float);
    }
    sum_num += SerializedSize(bias_size_);
    sum_num += SerializedSize(scale_size_);
    sum_num += SerializedSize(hidden_size_);
    sum_num += SerializedSize(eps_);
    return sum_num;
  }

  void serialize(void* buffer) const TRT_NOEXCEPT override {
    // the first var is for with_fp16, we will use it later;
    SerializeValue(&buffer, with_fp16_);
    SerializeValue(&buffer, emb_sizes_);
    SerializeValue(&buffer, bias_size_);
    SerializeValue(&buffer, scale_size_);
    if (with_fp16_) {
      for (size_t i = 0; i < emb_sizes_.size(); i++) {
        auto size = emb_sizes_[i];
        for (int j = 0; j < size; ++j) {
          SerializeValue(&buffer, reinterpret_cast<half*>(embs_[i])[j]);
        }
      }
      for (int i = 0; i < bias_size_; ++i) {
        SerializeValue(&buffer, reinterpret_cast<half*>(bias_)[i]);
      }
      for (int i = 0; i < scale_size_; ++i) {
        SerializeValue(&buffer, reinterpret_cast<half*>(scale_)[i]);
      }
    } else {
      for (size_t i = 0; i < emb_sizes_.size(); i++) {
        auto size = emb_sizes_[i];
        for (int j = 0; j < size; ++j) {
          SerializeValue(&buffer, reinterpret_cast<float*>(embs_[i])[j]);
        }
      }
      for (int i = 0; i < bias_size_; ++i) {
        SerializeValue(&buffer, reinterpret_cast<float*>(bias_)[i]);
      }
      for (int i = 0; i < scale_size_; ++i) {
        SerializeValue(&buffer, reinterpret_cast<float*>(scale_)[i]);
      }
    }
    SerializeValue(&buffer, hidden_size_);
    SerializeValue(&buffer, eps_);
  }

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index,
      const nvinfer1::DimsExprs* inputs,
      int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder)  // NOLINT
      TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* in_out,
                                 int nb_inputs,
                                 int nb_outputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nb_inputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nb_outputs) TRT_NOEXCEPT override {}

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nb_inputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nb_outputs) const TRT_NOEXCEPT override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
              const nvinfer1::PluginTensorDesc* output_desc,
              const void* const* inputs,
              void* const* outputs,
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* input_types,
                                       int nb_inputs) const
      TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override {
    if (own_host_buff_) {
      if (with_fp16_) {
        for (auto ptr : embs_) {
          delete[] reinterpret_cast<half*>(ptr);
        }
        delete[] reinterpret_cast<half*>(bias_);
        delete[] reinterpret_cast<half*>(scale_);
      } else {
        for (auto ptr : embs_) {
          delete[] reinterpret_cast<float*>(ptr);
        }
        delete[] reinterpret_cast<float*>(bias_);
        delete[] reinterpret_cast<float*>(scale_);
      }
    }
    delete impl_;
    delete this;
  }

 private:
  std::vector<void*> embs_;
  void* bias_{nullptr};
  void* scale_{nullptr};
  std::vector<int> emb_sizes_;
  int bias_size_;
  int scale_size_;
  int hidden_size_;
  float eps_;
  bool own_host_buff_{false};
  EmbEltwiseLayernormPluginDynamicImplBase* impl_{nullptr};

  void shareGPUData(const EmbEltwiseLayernormPluginDynamic* anthor) {
    impl_->shareGPUData(anthor->impl_);
  }

  template <typename U>
  void instantiateImpl() {
    std::vector<U*> embs;
    embs.resize(embs_.size());
    for (size_t i = 0; i < embs_.size(); ++i) {
      embs[i] = reinterpret_cast<U*>(embs_[i]);
    }
    impl_ = new EmbEltwiseLayernormPluginDynamicImpl<U>(
        embs,
        reinterpret_cast<U*>(bias_),
        reinterpret_cast<U*>(scale_),
        emb_sizes_,
        bias_size_,
        scale_size_,
        hidden_size_,
        eps_);
  }
};

class EmbEltwiseLayernormPluginDynamicCreator
    : public nvinfer1::IPluginCreator {
 public:
  EmbEltwiseLayernormPluginDynamicCreator() {}
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "fused_embedding_eltwise_layernorm_plugin";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  const nvinfer1::PluginFieldCollection* getFieldNames()
      TRT_NOEXCEPT override {
    return &field_collection_;
  }

  nvinfer1::IPluginV2* createPlugin(const char* name,
                                    const nvinfer1::PluginFieldCollection* fc)
      TRT_NOEXCEPT override {
    return nullptr;
  }

  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    return new EmbEltwiseLayernormPluginDynamic(serial_data, serial_length);
  }

  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_;
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(EmbEltwiseLayernormPluginDynamicCreator);
#endif
}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
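The deleted header above deserializes its host buffers by walking a raw byte cursor: read a value, then advance the pointer and shrink the remaining length, using reinterpret_cast<char const*&> to step a void pointer. A self-contained sketch of the same cursor pattern (generic; not Paddle's DeserializeValue helper):

#include <cstddef>
#include <cstring>

// Copy the next sizeof(T) bytes out of the buffer and advance the cursor.
template <typename T>
void ReadAndAdvance(char const*& cursor, size_t& remaining, T* out) {
  std::memcpy(out, cursor, sizeof(T));
  cursor += sizeof(T);
  remaining -= sizeof(T);
}

The replacement ManyEmb plugins keep the same buffer-walking idea but serialize whole device-side weight blocks with serFromDev instead of looping SerializeValue over individual half/float elements.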
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_kernel.cu (new file, 0 → 100644)

This diff is collapsed in the original view.
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu → paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernel_hface.cu

@@ -33,7 +33,6 @@ template <typename T, unsigned TPB>
 __global__ void embLayerNormKernelHFace_2(int32_t ld,
                                           int32_t const* inputIds0,
                                           int32_t const* inputIds1,
                                           int32_t nbLookupTables,
                                           float const* beta,
                                           float const* gamma,
                                           T const* mIdsEmbDev0,
@@ -93,7 +92,6 @@ __global__ void embLayerNormKernelHFace_3(int32_t ld,
                                           int32_t const* inputIds0,
                                           int32_t const* inputIds1,
                                           int32_t const* inputIds2,
                                           int32_t nbLookupTables,
                                           float const* beta,
                                           float const* gamma,
                                           T const* mIdsEmbDev0,
@@ -168,7 +166,6 @@ __global__ void embLayerNormKernelHFace_4(int32_t ld,
                                           int32_t const* inputIds1,
                                           int32_t const* inputIds2,
                                           int32_t const* inputIds3,
                                           int32_t nbLookupTables,
                                           float const* beta,
                                           float const* gamma,
                                           T const* mIdsEmbDev0,
@@ -273,7 +270,6 @@ int32_t embSkipLayerNormHFace_2(cudaStream_t stream,
       <<<grid, block, cache_size, stream>>>(ld,
                                             inputIds0,
                                             inputIds1,
                                             nbLookupTables,
                                             beta,
                                             gamma,
                                             mIdsEmbDev0,
@@ -311,7 +307,6 @@ int32_t embSkipLayerNormHFace_3(cudaStream_t stream,
                                             inputIds0,
                                             inputIds1,
                                             inputIds2,
                                             nbLookupTables,
                                             beta,
                                             gamma,
                                             mIdsEmbDev0,
@@ -355,7 +350,6 @@ int32_t embSkipLayerNormHFace_4(cudaStream_t stream,
                                             inputIds1,
                                             inputIds2,
                                             inputIds3,
                                             nbLookupTables,
                                             beta,
                                             gamma,
                                             mIdsEmbDev0,
paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu → paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernel_mtron.cu

@@ -33,7 +33,6 @@ template <typename T, unsigned TPB>
 __global__ void embLayerNormKernelMTron_2(int32_t ld,
                                           int32_t const* inputIds0,
                                           int32_t const* inputIds1,
                                           int32_t nbLookupTables,
                                           float const* beta,
                                           float const* gamma,
                                           T const* mIdsEmbDev0,
@@ -95,7 +94,6 @@ __global__ void embLayerNormKernelMTron_3(int32_t ld,
                                           int32_t const* inputIds0,
                                           int32_t const* inputIds1,
                                           int32_t const* inputIds2,
                                           int32_t nbLookupTables,
                                           float const* beta,
                                           float const* gamma,
                                           T const* mIdsEmbDev0,
@@ -172,7 +170,6 @@ __global__ void embLayerNormKernelMTron_4(int32_t ld,
                                           int32_t const* inputIds1,
                                           int32_t const* inputIds2,
                                           int32_t const* inputIds3,
                                           int32_t nbLookupTables,
                                           float const* beta,
                                           float const* gamma,
                                           T const* mIdsEmbDev0,
@@ -280,7 +277,6 @@ int32_t embSkipLayerNormMTron_2(cudaStream_t stream,
       <<<grid, block, cache_size, stream>>>(ld,
                                             inputIds0,
                                             inputIds1,
                                             nbLookupTables,
                                             beta,
                                             gamma,
                                             mIdsEmbDev0,
@@ -320,7 +316,6 @@ int32_t embSkipLayerNormMTron_3(cudaStream_t stream,
                                             inputIds0,
                                             inputIds1,
                                             inputIds2,
                                             nbLookupTables,
                                             beta,
                                             gamma,
                                             mIdsEmbDev0,
@@ -366,7 +361,6 @@ int32_t embSkipLayerNormMTron_4(cudaStream_t stream,
                                             inputIds1,
                                             inputIds2,
                                             inputIds3,
                                             nbLookupTables,
                                             beta,
                                             gamma,
                                             mIdsEmbDev0,
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.cu
0 → 100644
浏览文件 @
22bfa579
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
// AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h"
#include <cuda.h>
#include <cstring>
#include <vector>
#include "NvInfer.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
constexpr
size_t
threadsPerCta128
=
2
*
2
*
32
;
constexpr
size_t
threadsPerCta256
=
1
*
4
*
32
;
constexpr
size_t
threadsPerCta384
=
1
*
8
*
32
;
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M
// dimension: (s + 16*warps_m - 1) / (16*warps_m);
constexpr
size_t
xmmasM128
=
4
;
constexpr
size_t
xmmasM256
=
16
;
constexpr
size_t
xmmasM384
=
24
;
// Packed mask size per batch. Layout is XMMAS_M * THREADS_PER_CTA.
constexpr
size_t
packedMaskSize128
=
xmmasM128
*
threadsPerCta128
;
constexpr
size_t
packedMaskSize256
=
xmmasM256
*
threadsPerCta256
;
constexpr
size_t
packedMaskSize384
=
xmmasM384
*
threadsPerCta384
;
char
const
*
EMB_LAYER_NORM_VERSION
{
"1"
};
char
const
*
EMB_LAYER_NORM_NAME
{
"ManyEmbLayerNormPluginDynamic"
};
// Static class fields initialization
nvinfer1
::
PluginFieldCollection
EmbLayerNormPluginCreator
::
mFC
{};
std
::
vector
<
nvinfer1
::
PluginField
>
EmbLayerNormPluginCreator
::
mPluginAttributes
;
EmbLayerNormPlugin
::
EmbLayerNormPlugin
(
std
::
string
const
&
name
,
nvinfer1
::
DataType
const
type
,
nvinfer1
::
Weights
const
&
beta
,
nvinfer1
::
Weights
const
&
gamma
,
const
std
::
vector
<
nvinfer1
::
Weights
>&
IdsEmb
)
:
mLayerName
(
name
),
mLd
(
beta
.
count
),
mType
(
type
),
mIdsEmb_
(
IdsEmb
),
nbLookupTables_
(
static_cast
<
int
>
(
IdsEmb
.
size
()))
{
// Assuming Weights.count is the number of elements and not bytes
assert
(
beta
.
count
==
gamma
.
count
);
mBeta
.
convertAndCopy
(
beta
,
nvinfer1
::
DataType
::
kFLOAT
);
mGamma
.
convertAndCopy
(
gamma
,
nvinfer1
::
DataType
::
kFLOAT
);
copyToDevice
(
&
mGamma
,
sizeof
(
float
)
*
mGamma
.
count
,
&
mGammaDev
);
copyToDevice
(
&
mBeta
,
sizeof
(
float
)
*
mBeta
.
count
,
&
mBetaDev
);
for
(
size_t
i
=
0
;
i
<
mIdsEmb_
.
size
();
++
i
)
{
assert
(
mIdsEmb_
[
i
].
count
%
mLd
==
0
);
mIdsVocabSize
.
push_back
(
int32_t
(
mIdsEmb_
[
i
].
count
/
mLd
));
WeightsWithOwnership
tem_weight
;
tem_weight
.
convertAndCopy
(
mIdsEmb_
[
i
],
mType
);
void
*
cudaMem
{
nullptr
};
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
cudaMem
,
getWeightsSize
(
tem_weight
,
mType
)));
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMemcpy
(
cudaMem
,
tem_weight
.
values
,
getWeightsSize
(
tem_weight
,
mType
),
cudaMemcpyHostToDevice
));
mIdsEmbPtrs
.
push_back
(
cudaMem
);
}
}
EmbLayerNormPlugin
::
EmbLayerNormPlugin
(
std
::
string
const
&
name
,
void
const
*
data
,
size_t
length
)
:
mLayerName
(
name
),
mGammaDev
(
nullptr
),
mBetaDev
(
nullptr
),
mIdsEmbPtrs
{},
mIdsEmb_
{}
{
// Deserialize in the same order as serialization
deserialize_value
(
&
data
,
&
length
,
&
mType
);
deserialize_value
(
&
data
,
&
length
,
&
mLd
);
deserialize_value
(
&
data
,
&
length
,
&
nbLookupTables_
);
for
(
int32_t
i
=
0
;
i
<
nbLookupTables_
;
++
i
)
{
int32_t
tem
;
deserialize_value
(
&
data
,
&
length
,
&
tem
);
mIdsVocabSize
.
push_back
(
tem
);
}
char
const
*
d
=
static_cast
<
char
const
*>
(
data
);
mBeta
.
convertAndCopy
(
&
d
,
mLd
,
nvinfer1
::
DataType
::
kFLOAT
);
mGamma
.
convertAndCopy
(
&
d
,
mLd
,
nvinfer1
::
DataType
::
kFLOAT
);
for
(
int32_t
i
=
0
;
i
<
nbLookupTables_
;
++
i
)
{
nvinfer1
::
Weights
pre_tem_weight
;
pre_tem_weight
.
type
=
mType
;
pre_tem_weight
.
count
=
mLd
*
size_t
(
mIdsVocabSize
[
i
]);
const
auto
nbBytes
=
mLd
*
size_t
(
mIdsVocabSize
[
i
])
*
getElementSize
(
mType
);
auto
destBuf
=
new
char
[
nbBytes
];
pre_tem_weight
.
values
=
destBuf
;
std
::
copy_n
(
d
,
nbBytes
,
destBuf
);
d
+=
nbBytes
;
mIdsEmb_
.
push_back
(
pre_tem_weight
);
}
}
// IPluginV2DynamicExt Methods
nvinfer1
::
IPluginV2DynamicExt
*
EmbLayerNormPlugin
::
clone
()
const
noexcept
{
TRANSFORMER_DEBUG_MSG
(
"EmbLayerNormPlugin clone"
);
auto
p
=
new
EmbLayerNormPlugin
(
mLayerName
,
mType
,
mBeta
,
mGamma
,
mIdsEmb_
);
p
->
setPluginNamespace
(
mNamespace
.
c_str
());
return
p
;
}
nvinfer1
::
DimsExprs
EmbLayerNormPlugin
::
getOutputDimensions
(
int32_t
outputIndex
,
nvinfer1
::
DimsExprs
const
*
inputs
,
int32_t
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
noexcept
{
assert
(
outputIndex
==
0
);
nvinfer1
::
DimsExprs
ret
;
ret
.
nbDims
=
3
;
ret
.
d
[
0
]
=
inputs
[
0
].
d
[
0
];
ret
.
d
[
1
]
=
inputs
[
0
].
d
[
1
];
ret
.
d
[
2
]
=
exprBuilder
.
constant
(
mLd
);
return
ret
;
}
bool
EmbLayerNormPlugin
::
supportsFormatCombination
(
int32_t
pos
,
nvinfer1
::
PluginTensorDesc
const
*
inOut
,
int32_t
nbInputs
,
int32_t
nbOutputs
)
noexcept
{
assert
(
nbOutputs
==
1
);
nvinfer1
::
PluginTensorDesc
const
&
prev
=
inOut
[
0
];
nvinfer1
::
PluginTensorDesc
const
&
desc
=
inOut
[
pos
];
if
(
desc
.
format
!=
nvinfer1
::
TensorFormat
::
kLINEAR
)
{
return
false
;
}
if
(
pos
==
0
)
{
return
desc
.
type
==
nvinfer1
::
DataType
::
kINT32
;
}
if
(
0
<
pos
&&
pos
<
nbInputs
)
{
assert
(
desc
.
dims
.
nbDims
==
prev
.
dims
.
nbDims
);
for
(
int
i
=
0
;
i
<
prev
.
dims
.
nbDims
;
++
i
)
{
assert
(
desc
.
dims
.
d
[
i
]
==
prev
.
dims
.
d
[
i
]);
}
return
desc
.
type
==
prev
.
type
;
}
if
(
pos
==
nbInputs
)
{
// output
return
desc
.
type
==
mType
&&
desc
.
dims
.
nbDims
==
3
&&
desc
.
dims
.
d
[
0
]
==
prev
.
dims
.
d
[
0
]
&&
desc
.
dims
.
d
[
1
]
==
prev
.
dims
.
d
[
1
];
}
}
void
EmbLayerNormPlugin
::
configurePlugin
(
nvinfer1
::
DynamicPluginTensorDesc
const
*
inputs
,
int32_t
nbInputs
,
nvinfer1
::
DynamicPluginTensorDesc
const
*
outputs
,
int32_t
nbOutputs
)
noexcept
{
TRANSFORMER_DEBUG_MSG
(
"EmbLayerNormPlugin configurePlugin"
);
assert
(
static_cast
<
size_t
>
(
outputs
[
0
].
desc
.
dims
.
d
[
2
])
==
static_cast
<
size_t
>
(
mLd
));
int32_t
const
B
=
inputs
[
0
].
desc
.
dims
.
d
[
0
];
if
(
B
>
0
)
{
assert
(
outputs
[
0
].
desc
.
dims
.
d
[
0
]
==
B
);
}
assert
(
outputs
[
0
].
desc
.
type
==
mType
);
}
size_t
EmbLayerNormPlugin
::
getWorkspaceSize
(
nvinfer1
::
PluginTensorDesc
const
*
inputs
,
int32_t
nbInputs
,
nvinfer1
::
PluginTensorDesc
const
*
outputs
,
int32_t
nbOutputs
)
const
noexcept
{
return
0
;
}
int32_t
EmbLayerNormPlugin
::
enqueue
(
nvinfer1
::
PluginTensorDesc
const
*
inputDesc
,
nvinfer1
::
PluginTensorDesc
const
*
outputDesc
,
void
const
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
noexcept
{
int32_t
batchSize
=
inputDesc
[
0
].
dims
.
d
[
0
];
int32_t
const
maxSeqlen
=
inputDesc
[
0
].
dims
.
d
[
1
];
if
(
maxSeqlen
>
512
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"EmbLayerNormPlugin support maxSeqlen is 512"
));
}
const
float
*
beta
=
mBetaDev
.
get
();
const
float
*
gamma
=
mGammaDev
.
get
();
if
(
mType
==
nvinfer1
::
DataType
::
kFLOAT
)
{
auto
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
if
(
nbLookupTables_
==
2
)
{
return
embSkipLayerNorm_2
<
float
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
maxSeqlen
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
1
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
output
);
}
else
if
(
nbLookupTables_
==
3
)
{
return
embSkipLayerNorm_3
<
float
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
maxSeqlen
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
2
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
output
);
}
else
if
(
nbLookupTables_
==
4
)
{
return
embSkipLayerNorm_4
<
float
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
maxSeqlen
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
static_cast
<
int32_t
const
*>
(
inputs
[
3
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
2
]),
static_cast
<
float
const
*>
(
mIdsEmbPtrs
[
3
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
mIdsVocabSize
[
3
],
output
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only support 2,3,4 lookup_tables fused "
));
}
}
else
if
(
mType
==
nvinfer1
::
DataType
::
kHALF
)
{
auto
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
if
(
nbLookupTables_
==
2
)
{
return
embSkipLayerNorm_2
<
half
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
maxSeqlen
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
1
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
output
);
}
else
if
(
nbLookupTables_
==
3
)
{
return
embSkipLayerNorm_3
<
half
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
maxSeqlen
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
2
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
output
);
}
else
if
(
nbLookupTables_
==
4
)
{
return
embSkipLayerNorm_4
<
half
>
(
stream
,
static_cast
<
int32_t
>
(
mLd
),
batchSize
,
maxSeqlen
,
static_cast
<
int32_t
const
*>
(
inputs
[
0
]),
static_cast
<
int32_t
const
*>
(
inputs
[
1
]),
static_cast
<
int32_t
const
*>
(
inputs
[
2
]),
static_cast
<
int32_t
const
*>
(
inputs
[
3
]),
nbLookupTables_
,
beta
,
gamma
,
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
0
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
1
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
2
]),
static_cast
<
half
const
*>
(
mIdsEmbPtrs
[
3
]),
mIdsVocabSize
[
0
],
mIdsVocabSize
[
1
],
mIdsVocabSize
[
2
],
mIdsVocabSize
[
3
],
output
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only support 2,3,4 lookup_tables fused "
));
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported type error, expected [kHALF,kFLOAT]"
));
}
return
STATUS_SUCCESS
;
}
// IPluginV2Ext Methods
nvinfer1::DataType EmbLayerNormPlugin::getOutputDataType(
    int32_t index,
    nvinfer1::DataType const* inputTypes,
    int32_t nbInputs) const noexcept {
  assert(index == 0);
  assert(mType == nvinfer1::DataType::kHALF ||
         mType == nvinfer1::DataType::kFLOAT);
  return mType;
}

// IPluginV2 Methods
char const* EmbLayerNormPlugin::getPluginType() const noexcept {
  return EMB_LAYER_NORM_NAME;
}

char const* EmbLayerNormPlugin::getPluginVersion() const noexcept {
  return EMB_LAYER_NORM_VERSION;
}

int32_t EmbLayerNormPlugin::getNbOutputs() const noexcept { return 1; }

int32_t EmbLayerNormPlugin::initialize() noexcept {
  TRANSFORMER_DEBUG_MSG("EmbLayerNormPlugin initialize");
  return 0;
}

void EmbLayerNormPlugin::terminate() noexcept {
  TRANSFORMER_DEBUG_MSG("EmbLayerNormPlugin terminate");
}

size_t EmbLayerNormPlugin::getSerializationSize() const noexcept {
  size_t const wordSize = getElementSize(mType);
  return 2 * sizeof(float) * mLd                            // beta + gamma
         + sizeof(mType)                                    //
         + sizeof(mLd)                                      //
         + mIdsVocabSize.size() * sizeof(mIdsVocabSize[0])  //
         + wordSize * mLd *
               accumulate(mIdsVocabSize.begin(), mIdsVocabSize.end(), 0)  // ids emb
         + sizeof(nbLookupTables_);  // numbers of lookup_table
}

void EmbLayerNormPlugin::serialize(void* buffer) const noexcept {
  serialize_value(&buffer, mType);
  serialize_value(&buffer, mLd);
  serialize_value(&buffer, nbLookupTables_);
  for (size_t i = 0; i < mIdsVocabSize.size(); ++i) {
    serialize_value(&buffer, mIdsVocabSize[i]);
  }
  char* d = static_cast<char*>(buffer);
  size_t const wordSize = getElementSize(mType);
  serFromDev(&d, mBetaDev.get(), mLd);
  serFromDev(&d, mGammaDev.get(), mLd);
  for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
    serFromDev(&d,
               static_cast<char*>(mIdsEmbPtrs[i]),
               mLd * mIdsVocabSize[i] * wordSize);
  }
}
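As a sanity check on the layout above, a small host-only sketch shows how the scalar fields, vocab sizes, and the beta/gamma/embedding payloads add up to getSerializationSize(); the hidden size and vocab sizes here are made-up example values.

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Hypothetical byte accounting mirroring getSerializationSize() above.
int main() {
  std::size_t ld = 768;                          // hidden size (example)
  std::vector<int32_t> vocab = {18000, 513, 2};  // per-table vocab sizes (example)
  std::size_t word = sizeof(float);              // element size for kFLOAT
  std::size_t total =
      2 * sizeof(float) * ld                     // beta + gamma
      + sizeof(int32_t)                          // serialized mType tag (4B assumed)
      + sizeof(std::size_t)                      // mLd
      + vocab.size() * sizeof(int32_t)           // mIdsVocabSize entries
      + word * ld * std::accumulate(vocab.begin(), vocab.end(), 0)  // tables
      + sizeof(int32_t);                         // nbLookupTables_
  std::printf("serialized bytes: %zu\n", total);
  return 0;
}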
void EmbLayerNormPlugin::destroy() noexcept {
  // This gets called when the network containing plugin is destroyed
  mBetaDev.reset(nullptr);
  mGammaDev.reset(nullptr);
  for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
    cudaFree(mIdsEmbPtrs[i]);
  }
  delete this;
}

void EmbLayerNormPlugin::setPluginNamespace(char const* libNamespace) noexcept {
  mNamespace = libNamespace;
}

char const* EmbLayerNormPlugin::getPluginNamespace() const noexcept {
  return mNamespace.c_str();
}

EmbLayerNormPluginCreator::EmbLayerNormPluginCreator() {}

char const* EmbLayerNormPluginCreator::getPluginName() const noexcept {
  return EMB_LAYER_NORM_NAME;
}

char const* EmbLayerNormPluginCreator::getPluginVersion() const noexcept {
  return EMB_LAYER_NORM_VERSION;
}

nvinfer1::PluginFieldCollection const*
EmbLayerNormPluginCreator::getFieldNames() noexcept {
  return &mFC;
}

bool initialize_fields(nvinfer1::PluginFieldCollection const* fc,
                       nvinfer1::Weights* beta,
                       nvinfer1::Weights* gamma,
                       std::vector<nvinfer1::Weights>* IdsEmb) {
  bool output_fp16 = false;
  for (int32_t i = 0; i < fc->nbFields; i++) {
    std::string field_name(fc->fields[i].name);
    if (field_name.compare("bert_embeddings_layernorm_beta") == 0) {
      TRANSFORMER_DEBUG_MSG("Building bert_embeddings_layernorm_beta...");
      beta->values = fc->fields[i].data;
      beta->count = fc->fields[i].length;
      beta->type = fieldTypeToDataType(fc->fields[i].type);
    }
    if (field_name.compare("bert_embeddings_layernorm_gamma") == 0) {
      TRANSFORMER_DEBUG_MSG("Building bert_embeddings_layernorm_gamma...");
      gamma->values = fc->fields[i].data;
      gamma->count = fc->fields[i].length;
      gamma->type = fieldTypeToDataType(fc->fields[i].type);
    }
    if (field_name.compare("output_fp16") == 0) {
      TRANSFORMER_DEBUG_MSG("Building output_fp16...");
      assert(fc->fields[i].type == nvinfer1::PluginFieldType::kINT32);
      output_fp16 = static_cast<int32_t const*>(fc->fields[i].data)[0] != 0;
    }
    if (field_name.compare("bert_embeddings_word_embeddings_" +
                           std::to_string(i - 3)) == 0) {
      TRANSFORMER_DEBUG_MSG(
          ("bert_embeddings_word_embeddings_" + std::to_string(i - 3)).c_str());
      nvinfer1::Weights tem;
      tem.values = fc->fields[i].data;
      tem.count = fc->fields[i].length;
      tem.type = fieldTypeToDataType(fc->fields[i].type);
      IdsEmb->push_back(tem);
    }
  }
  return output_fp16;
}
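initialize_fields expects the converter to pass beta, gamma, and output_fp16 first, followed by embedding tables named bert_embeddings_word_embeddings_0, _1, and so on, so that the (i - 3) index check matches each table's position. A hypothetical caller assembling such a collection could look like this; the data pointers and sizes are placeholders.

#include <vector>
#include "NvInferRuntime.h"

// Hypothetical construction of the field collection consumed by
// initialize_fields(); buffer pointers below are placeholders.
nvinfer1::PluginFieldCollection make_fields(
    const float* beta_data, const float* gamma_data, const int32_t* fp16_flag,
    const float* word_emb, int32_t hidden, int32_t vocab,
    std::vector<nvinfer1::PluginField>* storage) {
  storage->emplace_back("bert_embeddings_layernorm_beta", beta_data,
                        nvinfer1::PluginFieldType::kFLOAT32, hidden);
  storage->emplace_back("bert_embeddings_layernorm_gamma", gamma_data,
                        nvinfer1::PluginFieldType::kFLOAT32, hidden);
  storage->emplace_back("output_fp16", fp16_flag,
                        nvinfer1::PluginFieldType::kINT32, 1);
  // Table i must be named "bert_embeddings_word_embeddings_<i>" so the
  // (i - 3) position check in initialize_fields matches.
  storage->emplace_back("bert_embeddings_word_embeddings_0", word_emb,
                        nvinfer1::PluginFieldType::kFLOAT32, hidden * vocab);
  nvinfer1::PluginFieldCollection fc;
  fc.nbFields = static_cast<int32_t>(storage->size());
  fc.fields = storage->data();
  return fc;
}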
nvinfer1::IPluginV2* EmbLayerNormPluginCreator::createPlugin(
    char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept {
  TRANSFORMER_DEBUG_MSG("EmbLayerNormVar createPlugin");
  nvinfer1::Weights beta;
  nvinfer1::Weights gamma;
  std::vector<nvinfer1::Weights> IdsEmb;
  bool output_fp16 = initialize_fields(fc, &beta, &gamma, &IdsEmb);
  TRANSFORMER_DEBUG_MSG("Building the Plugin...");
  EmbLayerNormPlugin* p = new EmbLayerNormPlugin(
      name,
      output_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
      beta,
      gamma,
      IdsEmb);
  return p;
}

nvinfer1::IPluginV2* EmbLayerNormPluginCreator::deserializePlugin(
    char const* name, void const* serialData, size_t serialLength) noexcept {
  return new EmbLayerNormPlugin(name, serialData, serialLength);
}

void EmbLayerNormPluginCreator::setPluginNamespace(
    char const* libNamespace) noexcept {
  mNamespace = libNamespace;
}

char const* EmbLayerNormPluginCreator::getPluginNamespace() const noexcept {
  return mNamespace.c_str();
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h
0 → 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
// AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include "NvInferPlugin.h"
#include "NvInferRuntime.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

template <typename T>
int32_t embSkipLayerNorm_2(cudaStream_t,
                           int32_t,
                           int32_t,
                           int32_t,
                           int32_t const*,
                           int32_t const*,
                           int32_t,
                           float const*,
                           float const*,
                           T const*,
                           T const*,
                           int32_t,
                           int32_t,
                           T*);

template <typename T>
int32_t embSkipLayerNorm_3(cudaStream_t,
                           int32_t,
                           int32_t,
                           int32_t,
                           int32_t const*,
                           int32_t const*,
                           int32_t const*,
                           int32_t,
                           float const*,
                           float const*,
                           T const*,
                           T const*,
                           T const*,
                           int32_t,
                           int32_t,
                           int32_t,
                           T*);

template <typename T>
int32_t embSkipLayerNorm_4(cudaStream_t,
                           int32_t,
                           int32_t,
                           int32_t,
                           int32_t const*,
                           int32_t const*,
                           int32_t const*,
                           int32_t const*,
                           int32_t,
                           float const*,
                           float const*,
                           T const*,
                           T const*,
                           T const*,
                           T const*,
                           int32_t,
                           int32_t,
                           int32_t,
                           int32_t,
                           T*);
class EmbLayerNormPlugin : public nvinfer1::IPluginV2DynamicExt {
 public:
  EmbLayerNormPlugin(std::string const& name,
                     nvinfer1::DataType const type,
                     nvinfer1::Weights const& beta,
                     nvinfer1::Weights const& gamma,
                     const std::vector<nvinfer1::Weights>& ids_emb);
  EmbLayerNormPlugin(std::string const& name, void const* data, size_t length);
  EmbLayerNormPlugin() = delete;

  nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
  nvinfer1::DimsExprs getOutputDimensions(
      int32_t outputIndex,
      const nvinfer1::DimsExprs* inputs,
      int32_t nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) noexcept override;
  void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in,
                       int32_t nbInputs,
                       nvinfer1::DynamicPluginTensorDesc const* out,
                       int32_t nbOutputs) noexcept override;
  int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
                  nvinfer1::PluginTensorDesc const* outputDesc,
                  void const* const* inputs,
                  void* const* outputs,
                  void* workspace,
                  cudaStream_t stream) noexcept override;
  int32_t initialize() noexcept override;
  void terminate() noexcept override;
  char const* getPluginVersion() const noexcept override;
  bool supportsFormatCombination(int32_t pos,
                                 nvinfer1::PluginTensorDesc const* inOut,
                                 int32_t nbInputs,
                                 int32_t nbOutputs) noexcept override;
  size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs,
                          int32_t nbInputs,
                          nvinfer1::PluginTensorDesc const* outputs,
                          int32_t nbOutputs) const noexcept override;
  nvinfer1::DataType getOutputDataType(
      int32_t index,
      nvinfer1::DataType const* inputTypes,
      int32_t nbInputs) const noexcept override;
  char const* getPluginType() const noexcept override;
  int32_t getNbOutputs() const noexcept override;
  size_t getSerializationSize() const noexcept override;
  void serialize(void* buffer) const noexcept override;
  void destroy() noexcept override;
  char const* getPluginNamespace() const noexcept override;
  void setPluginNamespace(char const* pluginNamespace) noexcept override;

 protected:
  std::string const mLayerName;
  std::string mNamespace;
  cuda_unique_ptr<float> mGammaDev;
  cuda_unique_ptr<float> mBetaDev;
  std::vector<void*> mIdsEmbPtrs;
  size_t mLd;  // leading dim = hidden size
  std::vector<int32_t> mIdsVocabSize;
  WeightsWithOwnership mBeta;
  WeightsWithOwnership mGamma;
  nvinfer1::DataType mType;
  std::vector<nvinfer1::Weights> mIdsEmb_;
  int32_t nbLookupTables_ = 0;
};
class EmbLayerNormPluginCreator : public nvinfer1::IPluginCreator {
 public:
  EmbLayerNormPluginCreator();
  char const* getPluginName() const noexcept override;
  const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
  void setPluginNamespace(char const* pluginNamespace) noexcept override;
  char const* getPluginNamespace() const noexcept override;
  nvinfer1::IPluginV2* createPlugin(
      char const* name,
      const nvinfer1::PluginFieldCollection* fc) noexcept override;
  char const* getPluginVersion() const noexcept override;
  nvinfer1::IPluginV2* deserializePlugin(char const* name,
                                         void const* serialData,
                                         size_t serialLength) noexcept override;

 protected:
  static nvinfer1::PluginFieldCollection mFC;
  static std::vector<nvinfer1::PluginField> mPluginAttributes;
  std::string mNamespace;
};

REGISTER_TRT_PLUGIN_V2(EmbLayerNormPluginCreator);

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
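With REGISTER_TRT_PLUGIN_V2 in place, the creator is discoverable through TensorRT's plugin registry. A hedged sketch of looking it up and building a plugin instance follows; the registered name/version strings are assumptions (the actual values come from EMB_LAYER_NORM_NAME and EMB_LAYER_NORM_VERSION, defined in the .cu file), and error handling is elided.

#include "NvInferRuntime.h"

// Illustrative registry lookup; "fc" is assumed to be assembled the way
// initialize_fields() expects (beta, gamma, output_fp16, then tables).
nvinfer1::IPluginV2* build_plugin(const nvinfer1::PluginFieldCollection* fc) {
  auto* creator = getPluginRegistry()->getPluginCreator(
      "ManyEmbLayerNormPluginDynamic", "1");  // name/version assumed
  if (creator == nullptr) return nullptr;
  return creator->createPlugin("emb_layernorm", fc);
}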
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
...
@@ -39,7 +39,8 @@ constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256;
 constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
 char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"};
 char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"2"};
-char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormPluginDynamic"};
+char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormVarlenPluginDynamic"};
 // Static class fields initialization
 nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{};
 std::vector<nvinfer1::PluginField>
...
@@ -167,7 +168,6 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginHFace::getOutputDimensions(
     assert(inputs[i].nbDims == inputs[1].nbDims);  // same shape
   }
   assert(inputs[0].nbDims == 1);  // pos_id: B+1
-  assert(outputIndex == 0 || outputIndex == 1);
   if (outputIndex == 0) {
     nvinfer1::DimsExprs ret;
     ret.nbDims = 4;
...
@@ -176,25 +176,32 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginHFace::getOutputDimensions(
     ret.d[2] = exprBuilder.constant(1);
     ret.d[3] = exprBuilder.constant(1);
     return ret;
-  }
-  // This is a hack: we just report some mask size and rely the plugins to play
-  // nicely together.
-  // At runtime, depending on the actual maxSeqlen, the size might be
-  // different.
-  int32_t maskSize_ = packedMaskSize384;
-  auto maskSize = exprBuilder.constant(maskSize_);
-  auto fp16maskSize = exprBuilder.operation(
-      nvinfer1::DimensionOperation::kPROD, *maskSize, *exprBuilder.constant(2));
-  auto Bplus1 = inputs[0].d[0];  // pos_id
-  auto one = exprBuilder.constant(1);
-  auto B = exprBuilder.operation(
-      nvinfer1::DimensionOperation::kSUB, *Bplus1, *one);
-  nvinfer1::DimsExprs ret;
-  ret.nbDims = 2;
-  ret.d[0] = B;
-  ret.d[1] = fp16maskSize;
-  return ret;
+  } else if (outputIndex == 1) {
+    // This is a hack: we just report some mask size and rely the plugins to
+    // play nicely together.
+    // At runtime, depending on the actual maxSeqlen, the size might be
+    // different.
+    int32_t maskSize_ = packedMaskSize384;
+    auto maskSize = exprBuilder.constant(maskSize_);
+    auto fp16maskSize = exprBuilder.operation(
+        nvinfer1::DimensionOperation::kPROD, *maskSize, *exprBuilder.constant(2));
+    auto Bplus1 = inputs[0].d[0];  // pos_id
+    auto one = exprBuilder.constant(1);
+    auto B = exprBuilder.operation(
+        nvinfer1::DimensionOperation::kSUB, *Bplus1, *one);
+    nvinfer1::DimsExprs ret;
+    ret.nbDims = 2;
+    ret.d[0] = B;
+    ret.d[1] = fp16maskSize;
+    return ret;
+  } else {
+    nvinfer1::DimsExprs ret;
+    ret.nbDims = 1;
+    ret.d[0] = inputs[nbInputs - 1].d[1];  // mask id: max seqlen
+    return ret;
+  }
 }
 nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginMTron::getOutputDimensions(
...
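After this change, the HFace variant reports three outputs instead of two: the embedded sequence (index 0), the packed fp16 mask (index 1), and a new 1-D tensor sized from the mask-id input's max sequence length. A registry-free sketch of that shape logic in plain C++, with invented names purely for illustration:

#include <cstdint>
#include <vector>

// Hypothetical mirror of the three output shapes reported above.
// B = batch, ld = hidden size, sumS = total tokens, maskS = packed mask size.
std::vector<int64_t> hface_output_dims(int32_t outputIndex, int64_t sumS,
                                       int64_t ld, int64_t B, int64_t maskS,
                                       int64_t maxSeqlen) {
  if (outputIndex == 0) return {sumS, ld, 1, 1};  // embedded sequence
  if (outputIndex == 1) return {B, 2 * maskS};    // fp16 mask (the "hack")
  return {maxSeqlen};                             // new: max-seqlen tensor
}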
@@ -209,14 +216,20 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginMTron::getOutputDimensions(
     assert(inputs[i].nbDims == inputs[1].nbDims);  // same shape
   }
   assert(inputs[0].nbDims == 1);  // pos_id: B+1
-  assert(outputIndex == 0 || outputIndex == 1);
-  nvinfer1::DimsExprs ret;
-  ret.nbDims = 4;
-  ret.d[0] = inputs[1].d[0];
-  ret.d[1] = exprBuilder.constant(mLd);
-  ret.d[2] = exprBuilder.constant(1);
-  ret.d[3] = exprBuilder.constant(1);
-  return ret;
+  if (outputIndex == 0 || outputIndex == 1) {
+    nvinfer1::DimsExprs ret;
+    ret.nbDims = 4;
+    ret.d[0] = inputs[1].d[0];  // sum of seq length
+    ret.d[1] = exprBuilder.constant(mLd);
+    ret.d[2] = exprBuilder.constant(1);
+    ret.d[3] = exprBuilder.constant(1);
+    return ret;
+  } else {
+    nvinfer1::DimsExprs ret;
+    ret.nbDims = 1;
+    ret.d[0] = inputs[nbInputs - 1].d[1];  // mask id: max seqlen
+    return ret;
+  }
 }
 bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
...
@@ -224,7 +237,7 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
     nvinfer1::PluginTensorDesc const* inOut,
     int32_t nbInputs,
     int32_t nbOutputs) noexcept {
-  assert(nbOutputs == 2);
+  assert(nbOutputs == 3);
   nvinfer1::PluginTensorDesc const& desc = inOut[pos];
   if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
     return false;
...
@@ -241,8 +254,8 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
     return desc.type == prev.type && desc.dims.nbDims == 1 &&
            desc.dims.d[0] == prev.dims.d[0];
   }
-  if (pos == nbInputs - 1) {  // max seq length
-    return desc.dims.nbDims == 1;
+  if (pos == nbInputs - 1) {  // mask id
+    return desc.type == prev.type;
   }
   // embedded sequence
   if (pos == nbInputs) {
...
@@ -250,8 +263,14 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
            desc.dims.d[0] == inOut[1].dims.d[0] && desc.dims.d[2] == 1 &&
            desc.dims.d[3] == 1;
   }
-  // mask
-  return desc.type == nvinfer1::DataType::kHALF;
+  // mask(HFace) or pre_layernorm_bias(MTron)
+  if (pos == nbInputs + 1) {
+    return desc.type == prev.type;
+  }
+  // max seqlen
+  if (pos == nbInputs + 2) {
+    return desc.type == prev.type;
+  }
 }
 void checkConfigurationInputs(nvinfer1::DynamicPluginTensorDesc const* inputs,
...
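The position checks above imply a fixed binding order for the varlen plugin: pos_id first, then the id tensors, mask_id as the last input, and then the three outputs. A small hypothetical helper makes that mapping explicit (the role strings are descriptive, not API names):

#include <cstdint>
#include <string>

// Hypothetical description of each binding position, matching the pos
// arithmetic in supportsFormatCombination (nbInputs includes mask_id).
std::string binding_role(int32_t pos, int32_t nbInputs) {
  if (pos == 0) return "pos_id (B+1 prefix sums)";
  if (pos < nbInputs - 1) return "input id tensor";
  if (pos == nbInputs - 1) return "mask_id";
  if (pos == nbInputs) return "output: embedded sequence";
  if (pos == nbInputs + 1) return "output: mask (HFace) / pre_layernorm_bias (MTron)";
  return "output: max seqlen";
}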
@@ -259,8 +278,7 @@ void checkConfigurationInputs(nvinfer1::DynamicPluginTensorDesc const* inputs,
     nvinfer1::DynamicPluginTensorDesc const* outputs,
     int32_t nbOutputs) noexcept {
   // Validate input arguments
   // assert(nbInputs == 4);
-  assert(nbOutputs == 2);
+  assert(nbOutputs == 3);
   assert(inputs[0].desc.dims.nbDims == 1);
   assert(inputs[0].desc.type == nvinfer1::DataType::kINT32);
   for (int i = 1; i < nbInputs - 1; ++i) {
...
@@ -671,7 +689,7 @@ char const* EmbLayerNormVarSeqlenPluginMTron::getPluginVersion()
 }
 int32_t EmbLayerNormVarSeqlenPluginBase::getNbOutputs() const noexcept {
-  return 2;
+  return 3;
 }
 int32_t EmbLayerNormVarSeqlenPluginHFace::initialize() noexcept {
...
paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
...
@@ -194,7 +194,6 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
   cuda_unique_ptr<float> mGammaDev;
   cuda_unique_ptr<float> mBetaDev;
   std::vector<void*> mIdsEmbPtrs;
-  // std::vector<void*> mIdsEmbDev;
   size_t mLd;  // leading dim = hidden size
   std::vector<int32_t> mIdsVocabSize;
   WeightsWithOwnership mBeta;
...
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.cc
...
@@ -28,11 +28,13 @@ limitations under the License. */
 namespace paddle {
 namespace inference {
+#if defined _WIN32
+#else
 TEST(AnalysisPredictor, no_fp16) {
   std::vector<float> result = {0.597841, 0.219972, 0.182187};
   trt_ernie(false, result);
 }
+#endif
 }  // namespace inference
 }  // namespace paddle
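trt_ernie runs the predictor and compares its first outputs against the expected vector above; a hedged sketch of such a closeness check is below, with the tolerance value an assumption rather than the test's actual threshold.

#include <cassert>
#include <cmath>
#include <vector>

// Illustrative output comparison in the spirit of these tests.
void check_close(const std::vector<float>& out, const std::vector<float>& ref) {
  assert(out.size() >= ref.size());
  for (size_t i = 0; i < ref.size(); ++i) {
    assert(std::fabs(out[i] - ref[i]) < 1e-5f);  // tolerance assumed
  }
}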
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
...
@@ -38,23 +38,23 @@ static void run(const AnalysisConfig& config, std::vector<float>* out_data) {
   int run_batch = 1;
   const int run_seq_len = 128;
-  std::vector<int64_t> tmp_input;
+  std::vector<int32_t> tmp_input;
   std::vector<float> tmp_four_input;
   tmp_input.reserve(run_batch * run_seq_len);
   tmp_four_input.reserve(run_batch * run_seq_len);
-  int64_t i0[run_seq_len] = {
+  int32_t i0[run_seq_len] = {
       1,   3558, 4,   75,  491, 89,  340,   313, 93,    4,
       255, 10,   75,  321, 4095, 1902, 4,   134, 49,    75,
       311, 14,   44,  178, 543, 15,  12043, 2,   75,    201,
       340, 9,    14,  44,  486, 218, 1140,  279, 12043, 2};
-  int64_t i1[run_seq_len] = {
+  int32_t i1[run_seq_len] = {
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0};
-  int64_t i2[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
+  int32_t i2[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                              10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                              20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                              30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
...
...
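Below this hunk (elided here), run() copies these buffers into the predictor. With the ids now int32, the copy would follow the usual ZeroCopyTensor pattern used in these tests; this is a hedged sketch with an assumed include path and shapes, not the test's verbatim code.

#include <vector>
#include "paddle/include/paddle_inference_api.h"  // include path assumed

// Sketch: feed the int32 ids into the first predictor input.
static void feed_ids(const paddle::AnalysisConfig& config,
                     std::vector<int32_t>& ids,  // int32_t now matches the model
                     int run_batch, int run_seq_len) {
  auto predictor = paddle::CreatePaddlePredictor(config);
  auto input_names = predictor->GetInputNames();
  auto input_t = predictor->GetInputTensor(input_names[0]);
  input_t->Reshape({run_batch, run_seq_len});
  input_t->copy_from_cpu(ids.data());
}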
@@ -136,11 +136,7 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
     precision = AnalysisConfig::Precision::kHalf;
   }
-#if defined _WIN32
-#else
   config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false);
-#endif
   config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, opt_input_shape);
   AnalysisConfig* config_deser = new AnalysisConfig(config);
...
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
...
@@ -33,18 +33,18 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data, int bs) {
   const int run_seq_len = 128;
   size_t len = run_batch * run_seq_len;
-  int64_t i0_bs1[run_seq_len] = {
+  int32_t i0_bs1[run_seq_len] = {
       1,   3558, 4,   75,  491, 89,  340,   313, 93,    4,
       255, 10,   75,  321, 4095, 1902, 4,   134, 49,    75,
       311, 14,   44,  178, 543, 15,  12043, 2,   75,    201,
       340, 9,    14,  44,  486, 218, 1140,  279, 12043, 2};
-  int64_t i1_bs1[run_seq_len] = {
+  int32_t i1_bs1[run_seq_len] = {
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0};
-  int64_t i2_bs1[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
+  int32_t i2_bs1[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                                  10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                                  20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                                  30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
...
...
@@ -52,7 +52,7 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data, int bs) {
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
-  std::vector<int64_t> i0_data(len), i1_data(len), i2_data(len);
+  std::vector<int32_t> i0_data(len), i1_data(len), i2_data(len);
   std::vector<float> i3_data(len);
   for (size_t i = 0; i < len; i++) {
...