Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e5cf75d8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e5cf75d8
编写于
12月 01, 2022
作者:
W
Wangzheee
提交者:
GitHub
12月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference] General optimization for no_varlen multihead (#48469)
* general optimization for no_varlen multihead
上级
aa892113
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
595 addition
and
145 deletion
+595
-145
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
...fluid/framework/ir/remove_padding_recover_padding_pass.cc
+0
-3
paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc
...framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc
+2
-2
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
+2
-2
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
+2
-2
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+1
-1
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+115
-8
paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc
...nference/tensorrt/convert/transformer_input_convert_op.cc
+1
-1
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu
.../fluid/inference/tensorrt/plugin/remove_padding_plugin.cu
+0
-1
paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu
...rence/tensorrt/plugin/transformer_input_convert_plugin.cu
+0
-122
paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.cu
...ensorrt/plugin/transformer_input_output_convert_plugin.cu
+356
-0
paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h
...tensorrt/plugin/transformer_input_output_convert_plugin.h
+115
-2
未找到文件。
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
浏览文件 @
e5cf75d8
...
...
@@ -439,9 +439,6 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
"remove_padding pass."
;
return
;
}
fc_op
->
Op
()
->
RemoveAttr
(
"in_num_col_dims"
);
fc_op
->
Op
()
->
SetAttr
(
"in_num_col_dims"
,
1
);
insert_remove_padding_op
(
fc_input
,
fc_op
);
insert_recover_padding_op
(
fc_op
,
fc_op
->
outputs
[
0
]);
found_subgraph_count
++
;
...
...
paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc
浏览文件 @
e5cf75d8
...
...
@@ -441,14 +441,14 @@ void TrtEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
std
::
string
mask_id
=
Get
<
std
::
string
>
(
"tensorrt_transformer_maskid"
);
if
((
use_varseqlen
&&
pos_id
!=
""
&&
mask_id
!=
""
)
||
(
!
use_varseqlen
&&
pos_id
==
""
&&
mask_id
==
""
))
{
(
!
use_varseqlen
&&
pos_id
==
""
))
{
VLOG
(
3
)
<<
"start trt_embedding_eltwise_layernorm_fuse_pass"
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Use transformer'varseqlen need config: "
"use_varseqlen, set pos_id, set "
"mask_id. Or not use varseqlen, do not set "
"pos_id
, set mask_id
. Please "
"pos_id. Please "
"reconfig"
));
}
graph
->
Set
(
kEmbEltwiseLayernormPass
,
new
bool
(
true
));
...
...
paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
浏览文件 @
e5cf75d8
...
...
@@ -1637,14 +1637,14 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
"preln_embedding_eltwise_layernorm_fuse_"
"pass. please use no_varseqlen"
));
}
}
else
if
(
!
use_varseqlen
&&
pos_id
==
""
&&
mask_id
==
""
)
{
}
else
if
(
!
use_varseqlen
&&
pos_id
==
""
)
{
VLOG
(
3
)
<<
"start no_varseqlen_trt_multihead_matmul_fuse_pass"
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Use transformer'varseqlen need config: "
"use_varseqlen, set pos_id, set "
"mask_id. Or not use varseqlen, do not set "
"pos_id
, set mask_id
. Please "
"pos_id. Please "
"reconfig"
));
}
graph
->
Set
(
kMultiheadMatmulPass
,
new
bool
(
true
));
...
...
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
浏览文件 @
e5cf75d8
...
...
@@ -207,14 +207,14 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
"trt_embedding_eltwise_layernorm_fuse_pass, "
"trt_multihead_matmul_fuse_pass. please use no_varseqlen"
));
}
}
else
if
(
!
use_varseqlen
&&
pos_id
==
""
&&
mask_id
==
""
)
{
}
else
if
(
!
use_varseqlen
&&
pos_id
==
""
)
{
VLOG
(
3
)
<<
"start no_varseqlen trt_skip_layernorm_fuse_pass"
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Use transformer'varseqlen need config: "
"use_varseqlen, set pos_id, set "
"mask_id. Or not use varseqlen, do not set "
"pos_id
, set mask_id
. Please "
"pos_id. Please "
"reconfig"
));
}
}
...
...
paddle/fluid/inference/tensorrt/convert/fc_op.cc
浏览文件 @
e5cf75d8
...
...
@@ -332,7 +332,7 @@ class FcOpConverter : public OpConverter {
}
// If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
// not add Shuffle layer in ernie's multihead.
if
(
x_dim
.
nbDims
==
4
&&
x_
num_col_dims
==
1
)
{
if
(
x_dim
.
nbDims
==
4
&&
x_
dim
.
d
[
2
]
==
1
&&
x_dim
.
d
[
3
]
==
1
)
{
if
(
enable_int8
||
support_int8
)
{
// add conv1x1 layer
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
...
...
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
e5cf75d8
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -87,7 +88,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
engine_
->
tensorrt_transformer_posid
()
!=
""
&&
engine_
->
tensorrt_transformer_maskid
()
!=
""
;
if
(
engine_
->
with_dynamic_shape
())
{
if
(
flag_varseqlen
)
{
if
(
engine_
->
tensorrt_transformer_maskid
()
!=
""
)
{
if
(
engine_
->
precision
()
==
AnalysisConfig
::
Precision
::
kFloat32
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"use use_varseqlen must be int8 or half, not float32."
));
...
...
@@ -98,8 +99,100 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1
::
Weights
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_data
),
static_cast
<
int32_t
>
(
bias_t
->
numel
())};
auto
max_seqlen_tensor
=
engine_
->
GetITensor
(
"max_seqlen_tensor"
);
auto
pos_id_tensor
=
engine_
->
GetITensor
(
"pos_id"
);
nvinfer1
::
ITensor
*
mask_tensor
;
nvinfer1
::
ITensor
*
pos_id_tensor
;
nvinfer1
::
ITensor
*
max_seqlen_tensor
;
auto
*
new_input
=
input
;
if
(
flag_varseqlen
)
{
mask_tensor
=
engine_
->
GetITensor
(
"qkv_plugin_mask"
);
pos_id_tensor
=
engine_
->
GetITensor
(
"pos_id"
);
max_seqlen_tensor
=
engine_
->
GetITensor
(
"max_seqlen_tensor"
);
}
else
{
auto
*
bias_qk_tensor
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"BiasQK"
).
front
());
auto
bias_qk_dims
=
bias_qk_tensor
->
getDimensions
();
PADDLE_ENFORCE_EQ
(
bias_qk_dims
.
nbDims
,
4
,
platform
::
errors
::
InvalidArgument
(
"The rank of Multihead Matmul'BiasQK must be "
"4, but got rank is %d."
,
bias_qk_dims
.
nbDims
));
nvinfer1
::
Dims
start_dims
=
bias_qk_dims
;
start_dims
.
d
[
0
]
=
0
;
start_dims
.
d
[
1
]
=
0
;
start_dims
.
d
[
2
]
=
0
;
start_dims
.
d
[
3
]
=
0
;
nvinfer1
::
Dims
size_dims
=
bias_qk_dims
;
nvinfer1
::
Dims
step_dims
=
bias_qk_dims
;
step_dims
.
d
[
0
]
=
1
;
step_dims
.
d
[
1
]
=
1
;
step_dims
.
d
[
2
]
=
1
;
step_dims
.
d
[
3
]
=
1
;
auto
*
shape_tensor
=
Shape
(
bias_qk_tensor
);
// (b,n,m,m) -> (b,1,m,1)
std
::
vector
<
nvinfer1
::
ITensor
*>
size_vec_tensor
;
size_vec_tensor
.
push_back
(
GetEleTensorOfShape
(
shape_tensor
,
0
));
size_vec_tensor
.
push_back
(
Add1DConstantLayer
(
1
));
size_vec_tensor
.
push_back
(
GetEleTensorOfShape
(
shape_tensor
,
2
));
size_vec_tensor
.
push_back
(
Add1DConstantLayer
(
1
));
auto
*
size_tensor
=
Concat
(
size_vec_tensor
);
auto
*
slice_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Slice
,
*
bias_qk_tensor
,
start_dims
,
size_dims
,
step_dims
);
slice_layer
->
setInput
(
2
,
*
size_tensor
);
// half -> bool
auto
*
cast_layer_0
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Identity
,
*
slice_layer
->
getOutput
(
0
));
cast_layer_0
->
setOutputType
(
0
,
nvinfer1
::
DataType
::
kBOOL
);
// bool kNOT
auto
*
not_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Unary
,
*
cast_layer_0
->
getOutput
(
0
),
nvinfer1
::
UnaryOperation
::
kNOT
);
// bool -> int32
auto
*
cast_layer_1
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Identity
,
*
not_layer
->
getOutput
(
0
));
cast_layer_1
->
setOutputType
(
0
,
nvinfer1
::
DataType
::
kINT32
);
// Calculate the number of 1 : (b,1,m,1) -> (b)
uint32_t
reduce_dim_0
=
0
;
reduce_dim_0
|=
1
<<
1
;
// 00000000000000000000000000000010
reduce_dim_0
|=
1
<<
2
;
// 00000000000000000000000000000110
reduce_dim_0
|=
1
<<
3
;
// 00000000000000000000000000001110
bool
keep_dim
=
false
;
nvinfer1
::
ReduceOperation
reduce_type
=
nvinfer1
::
ReduceOperation
::
kSUM
;
auto
*
reduce_sum_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
cast_layer_1
->
getOutput
(
0
),
reduce_type
,
reduce_dim_0
,
keep_dim
);
std
::
vector
<
nvinfer1
::
ITensor
*>
inputs_transformer
;
inputs_transformer
.
emplace_back
(
input
);
inputs_transformer
.
emplace_back
(
reduce_sum_layer
->
getOutput
(
0
));
// (b,m)
plugin
::
TransformerInputConvertPlugin
*
plugin
=
new
plugin
::
TransformerInputConvertPlugin
();
nvinfer1
::
ILayer
*
transformer_input_layer
=
engine_
->
AddDynamicPlugin
(
inputs_transformer
.
data
(),
inputs_transformer
.
size
(),
plugin
);
new_input
=
transformer_input_layer
->
getOutput
(
0
);
mask_tensor
=
transformer_input_layer
->
getOutput
(
1
);
pos_id_tensor
=
transformer_input_layer
->
getOutput
(
2
);
max_seqlen_tensor
=
transformer_input_layer
->
getOutput
(
3
);
}
if
(
engine_
->
with_interleaved
())
{
VLOG
(
4
)
<<
"fused multihead_matmul op: use_varseqlen and "
"with_interleaved"
;
...
...
@@ -111,7 +204,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
float
dp_probs
=
1.0
/
127.0
;
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
engine_
,
Convolution
,
*
new_
input
,
n
,
nv_ksize
,
weight
,
bias
);
fc_layer
->
setName
(
(
"Multihead: Convolution/FullyConnected: (Output: "
+
output_name
+
")"
)
...
...
@@ -220,10 +313,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
if
(
op_desc
.
HasAttr
(
"Input_scale"
))
{
nvinfer1
::
DimsHW
nv_ksize
(
1
,
1
);
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Convolution
,
*
input
,
n
,
nv_ksize
,
weight
,
bias
);
engine_
,
Convolution
,
*
new_
input
,
n
,
nv_ksize
,
weight
,
bias
);
}
else
{
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
input
,
n
,
weight
,
bias
);
engine_
,
FullyConnected
,
*
new_
input
,
n
,
weight
,
bias
);
}
if
(
op_desc
.
HasAttr
(
"fc_out_threshold"
))
{
...
...
@@ -282,14 +375,28 @@ class MultiheadMatMulOpConverter : public OpConverter {
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
"qkv_plugin_mask"
)
);
plugin_inputs
.
emplace_back
(
mask_tensor
);
plugin_inputs
.
emplace_back
(
pos_id_tensor
);
plugin_inputs
.
emplace_back
(
max_seqlen_tensor
);
// max_seqlen, eval_placeholder_3
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
layer
=
plugin_layer
;
// recover no_varlen output
if
(
!
flag_varseqlen
)
{
std
::
vector
<
nvinfer1
::
ITensor
*>
output_transformer
;
output_transformer
.
emplace_back
(
plugin_layer
->
getOutput
(
0
));
output_transformer
.
emplace_back
(
input
);
output_transformer
.
emplace_back
(
pos_id_tensor
);
plugin
::
TransformerOutputConvertPlugin
*
plugin
=
new
plugin
::
TransformerOutputConvertPlugin
();
nvinfer1
::
ILayer
*
transformer_output_layer
=
engine_
->
AddDynamicPlugin
(
output_transformer
.
data
(),
output_transformer
.
size
(),
plugin
);
layer
=
transformer_output_layer
;
}
}
}
else
{
if
(
input_dims
.
d
[
1
]
<=
384
&&
!
bias_qk_attr
&&
...
...
paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc
浏览文件 @
e5cf75d8
...
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_
output_
convert_plugin.h"
namespace
paddle
{
namespace
framework
{
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
e5cf75d8
...
...
@@ -26,7 +26,7 @@ list(
deformable_conv_op_plugin.cu
matmul_op_int8_plugin.cu
multihead_matmul_roformer_plugin.cu
transformer_input_convert_plugin.cu
transformer_input_
output_
convert_plugin.cu
remove_padding_plugin.cu
recover_padding_plugin.cu
c_allreduce_op_plugin.cu
...
...
paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu
浏览文件 @
e5cf75d8
...
...
@@ -105,7 +105,6 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
const
auto
input_desc
=
inputDesc
[
0
];
const
half
*
input0
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
int32_t
*
input1
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
// pos_id_tensor
...
...
paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu
已删除
100644 → 0
浏览文件 @
aa892113
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
__global__
void
TransformerInputConvertKernel
(
const
int64_t
*
input
,
int32_t
*
output0
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
__shared__
int32_t
shared_data
;
if
(
threadIdx
.
x
==
static_cast
<
int
>
(
input
[
tid
]))
{
atomicAdd
(
&
shared_data
,
1
);
}
output0
[
0
]
=
0
;
output0
[
blockIdx
.
x
+
1
]
=
shared_data
;
__syncthreads
();
for
(
int
i
=
0
;
i
<
blockDim
.
x
;
++
i
)
{
output0
[
i
+
1
]
+=
output0
[
i
];
}
}
nvinfer1
::
DataType
TransformerInputConvertPlugin
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
return
nvinfer1
::
DataType
::
kINT32
;
}
nvinfer1
::
DimsExprs
TransformerInputConvertPlugin
::
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
TRT_NOEXCEPT
{
nvinfer1
::
DimsExprs
output_dims
{};
output_dims
.
nbDims
=
1
;
if
(
outputIndex
==
0
)
{
// PosId
const
auto
*
one
=
exprBuilder
.
constant
(
1
);
output_dims
.
d
[
0
]
=
exprBuilder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
inputs
[
0
].
d
[
0
],
*
one
);
}
else
{
// MaxSeqlen
output_dims
.
d
[
0
]
=
inputs
[
0
].
d
[
1
];
}
return
output_dims
;
}
bool
TransformerInputConvertPlugin
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nbInputs
,
1
,
platform
::
errors
::
InvalidArgument
(
"Must have 1 inputs, "
"but got %d input(s). "
,
nbInputs
));
PADDLE_ENFORCE_EQ
(
nbOutputs
,
getNbOutputs
(),
platform
::
errors
::
InvalidArgument
(
"Must have 2 output, "
"but got %d output(s). "
,
nbOutputs
));
if
(
pos
==
0
)
{
// input
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
else
{
// output0, output1
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
&&
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
}
void
TransformerInputConvertPlugin
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
outputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{}
void
TransformerInputConvertPlugin
::
attachToContext
(
cudnnContext
*
cudnnContext
,
cublasContext
*
cublasContext
,
nvinfer1
::
IGpuAllocator
*
gpuAllocator
)
TRT_NOEXCEPT
{}
void
TransformerInputConvertPlugin
::
detachFromContext
()
TRT_NOEXCEPT
{}
void
TransformerInputConvertPlugin
::
terminate
()
TRT_NOEXCEPT
{}
int
TransformerInputConvertPlugin
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
const
auto
input_desc
=
inputDesc
[
0
];
const
int64_t
*
input
=
static_cast
<
const
int64_t
*>
(
inputs
[
0
]);
int32_t
*
output0
=
static_cast
<
int32_t
*>
(
outputs
[
0
]);
// PosId
// int32_t* output1 = static_cast<int32_t*>(outputs[1]); // MaxSeqlen
const
int32_t
num_blocks
=
input_desc
.
dims
.
d
[
0
];
// batchs
const
int32_t
num_threads
=
input_desc
.
dims
.
d
[
1
];
// max sequnce length
TransformerInputConvertKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
input
,
output0
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.cu
0 → 100644
浏览文件 @
e5cf75d8
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h"
#include "cub/cub.cuh"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
__global__
void
remove_padding_kernel
(
const
half
*
input0
,
const
int32_t
*
input1
,
half
*
output
)
{
int
word_id
=
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
;
int32_t
seqence_length
=
input1
[
blockIdx
.
x
+
1
]
-
input1
[
blockIdx
.
x
];
if
(
blockIdx
.
y
<
seqence_length
)
{
output
[(
input1
[
blockIdx
.
x
]
+
blockIdx
.
y
)
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
input0
[
word_id
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
];
}
}
__global__
void
recover_padding_kernel
(
const
half
*
input0
,
const
int32_t
*
input1
,
half
*
output
)
{
int
word_id
=
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
;
int32_t
seqence_length
=
input1
[
blockIdx
.
x
+
1
]
-
input1
[
blockIdx
.
x
];
if
(
blockIdx
.
y
<
seqence_length
)
{
output
[
word_id
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
input0
[(
input1
[
blockIdx
.
x
]
+
blockIdx
.
y
)
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
];
}
else
{
output
[
word_id
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
0
;
}
}
nvinfer1
::
DataType
TransformerInputConvertPlugin
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
if
(
index
==
0
)
{
// new input
return
nvinfer1
::
DataType
::
kHALF
;
}
else
if
(
index
==
1
)
{
// mask
return
nvinfer1
::
DataType
::
kHALF
;
}
else
if
(
index
==
2
)
{
// pos id
return
nvinfer1
::
DataType
::
kINT32
;
}
else
if
(
index
==
3
)
{
// max_seqlen_tensor
return
nvinfer1
::
DataType
::
kHALF
;
}
}
nvinfer1
::
DimsExprs
TransformerInputConvertPlugin
::
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
TRT_NOEXCEPT
{
constexpr
size_t
threadsPerCta384
=
1
*
8
*
32
;
constexpr
size_t
xmmasM384
=
24
;
constexpr
size_t
packedMaskSize384
=
xmmasM384
*
threadsPerCta384
;
int32_t
maskSize_
=
packedMaskSize384
;
auto
maskSize
=
exprBuilder
.
constant
(
maskSize_
);
auto
fp16maskSize
=
exprBuilder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
maskSize
,
*
exprBuilder
.
constant
(
2
));
auto
one
=
exprBuilder
.
constant
(
1
);
auto
B
=
inputs
[
0
].
d
[
0
];
auto
MaxLength
=
inputs
[
0
].
d
[
1
];
auto
Hidden
=
inputs
[
0
].
d
[
2
];
nvinfer1
::
DimsExprs
output_dims
;
if
(
outputIndex
==
0
)
{
// new input
output_dims
.
nbDims
=
4
;
output_dims
.
d
[
0
]
=
exprBuilder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
B
,
*
MaxLength
);
output_dims
.
d
[
1
]
=
Hidden
;
output_dims
.
d
[
2
]
=
exprBuilder
.
constant
(
1
);
output_dims
.
d
[
3
]
=
exprBuilder
.
constant
(
1
);
}
else
if
(
outputIndex
==
1
)
{
// mask
output_dims
.
nbDims
=
2
;
output_dims
.
d
[
0
]
=
B
;
output_dims
.
d
[
1
]
=
fp16maskSize
;
}
else
if
(
outputIndex
==
2
)
{
// pos id
output_dims
.
nbDims
=
1
;
output_dims
.
d
[
0
]
=
exprBuilder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
B
,
*
one
);
}
else
if
(
outputIndex
==
3
)
{
// max_seqlen_tensor
output_dims
.
nbDims
=
1
;
output_dims
.
d
[
0
]
=
MaxLength
;
}
return
output_dims
;
}
bool
TransformerInputConvertPlugin
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nbInputs
,
2
,
platform
::
errors
::
InvalidArgument
(
"TransformerInputConvertPlugin must have 2 inputs, "
"but got %d input(s). "
,
nbInputs
));
PADDLE_ENFORCE_EQ
(
nbOutputs
,
4
,
platform
::
errors
::
InvalidArgument
(
"TransformerInputConvertPlugin must have 4 outputs, "
"but got %d output(s). "
,
nbOutputs
));
if
(
pos
==
0
)
{
// input
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
;
}
else
if
(
pos
==
1
)
{
// reducesum_qk_bias
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
;
}
else
if
(
pos
==
2
)
{
// new input
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
;
}
else
if
(
pos
==
3
)
{
// mask
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
;
}
else
if
(
pos
==
4
)
{
// pos id
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
;
}
else
if
(
pos
==
5
)
{
// max_seqlen_tensor
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
;
}
}
void
TransformerInputConvertPlugin
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
outputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{}
void
TransformerInputConvertPlugin
::
attachToContext
(
cudnnContext
*
cudnnContext
,
cublasContext
*
cublasContext
,
nvinfer1
::
IGpuAllocator
*
gpuAllocator
)
TRT_NOEXCEPT
{}
void
TransformerInputConvertPlugin
::
detachFromContext
()
TRT_NOEXCEPT
{}
void
TransformerInputConvertPlugin
::
terminate
()
TRT_NOEXCEPT
{}
int
TransformerInputConvertPlugin
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
// input(no_varlen), reducesum_qk_bias, input(varlen), mask, pos_id,
// max_seqlen_tensor
const
half
*
input0
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
// input(no_varlen)
const
int32_t
*
input1
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
// reducesum_qk_bias
half
*
output0
=
static_cast
<
half
*>
(
outputs
[
0
]);
// input(varlen)
int32_t
*
output2
=
static_cast
<
int32_t
*>
(
outputs
[
2
]);
// pos_id
const
auto
input0_desc
=
inputDesc
[
0
];
const
int32_t
B
=
input0_desc
.
dims
.
d
[
0
];
// batchs
const
int32_t
MaxLength
=
input0_desc
.
dims
.
d
[
1
];
// max token length
const
int32_t
HiddenSize
=
input0_desc
.
dims
.
d
[
2
];
// hidden size
// Determine temporary device storage requirements
void
*
d_temp_storage
=
NULL
;
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
temp_storage_bytes
,
input1
,
output2
,
B
+
1
);
// Allocate temporary storage
cudaMalloc
(
&
d_temp_storage
,
temp_storage_bytes
);
// Run exclusive prefix sum
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
temp_storage_bytes
,
input1
,
output2
,
B
+
1
);
const
int32_t
vector_length
=
HiddenSize
;
int32_t
num_threads
;
if
(
vector_length
<
1024
)
{
num_threads
=
vector_length
;
}
else
{
if
(
vector_length
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
vector_length
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
vector_length
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
vector_length
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
vector_length
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
vector_length
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
vector_length
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
vector_length
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
vector_length
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
}
const
dim3
num_blocks
(
B
,
MaxLength
,
vector_length
/
num_threads
);
// batchs, max sequnce length, input0.dims.d[2]/*
remove_padding_kernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
input0
,
output2
,
output0
);
// input(no_varlen), pos_id, input(varlen)
return
cudaGetLastError
()
!=
cudaSuccess
;
}
nvinfer1
::
DataType
TransformerOutputConvertPlugin
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
if
(
index
==
0
)
{
return
nvinfer1
::
DataType
::
kHALF
;
}
}
nvinfer1
::
DimsExprs
TransformerOutputConvertPlugin
::
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
TRT_NOEXCEPT
{
nvinfer1
::
DimsExprs
output_dims
;
if
(
outputIndex
==
0
)
{
output_dims
=
inputs
[
1
];
}
return
output_dims
;
}
bool
TransformerOutputConvertPlugin
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
nbInputs
,
3
,
platform
::
errors
::
InvalidArgument
(
"TransformerOutputConvertPlugin must have 3 inputs, "
"but got %d input(s). "
,
nbInputs
));
PADDLE_ENFORCE_EQ
(
nbOutputs
,
1
,
platform
::
errors
::
InvalidArgument
(
"TransformerOutputConvertPlugin must have 1 output, "
"but got %d output(s). "
,
nbOutputs
));
if
(
pos
==
0
)
{
// qkv plugin output(varlen)
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
;
}
else
if
(
pos
==
1
)
{
// qkv plugin input(no_varlen)
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
;
}
else
if
(
pos
==
2
)
{
// pos id
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
;
}
else
if
(
pos
==
3
)
{
// qkv plugin output(no_varlen)
return
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
&&
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
;
}
}
void
TransformerOutputConvertPlugin
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
outputs
,
int
nbOutputs
)
TRT_NOEXCEPT
{}
void
TransformerOutputConvertPlugin
::
attachToContext
(
cudnnContext
*
cudnnContext
,
cublasContext
*
cublasContext
,
nvinfer1
::
IGpuAllocator
*
gpuAllocator
)
TRT_NOEXCEPT
{}
void
TransformerOutputConvertPlugin
::
detachFromContext
()
TRT_NOEXCEPT
{}
void
TransformerOutputConvertPlugin
::
terminate
()
TRT_NOEXCEPT
{}
int
TransformerOutputConvertPlugin
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
const
half
*
input0
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
// qkv plugin output(varlen)
const
half
*
input1
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
// qkv plugin input(no_varlen)
const
int32_t
*
input2
=
static_cast
<
const
int32_t
*>
(
inputs
[
2
]);
// pos id
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
// qkv plugin output(no_varlen)
const
auto
input1_desc
=
inputDesc
[
1
];
const
int32_t
B
=
input1_desc
.
dims
.
d
[
0
];
// batchs
const
int32_t
MaxLength
=
input1_desc
.
dims
.
d
[
1
];
// max token length
const
int32_t
HiddenSize
=
input1_desc
.
dims
.
d
[
2
];
// hidden size
const
int32_t
vector_length
=
HiddenSize
;
int32_t
num_threads
;
if
(
vector_length
<
1024
)
{
num_threads
=
vector_length
;
}
else
{
if
(
vector_length
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
vector_length
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
vector_length
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
vector_length
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
vector_length
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
vector_length
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
vector_length
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
vector_length
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
vector_length
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
}
const
dim3
num_blocks
(
B
,
MaxLength
,
vector_length
/
num_threads
);
// batchs, max sequnce length
// (mask_id.dims.d[1]),
// input.dims.d[1]/*
recover_padding_kernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
input0
,
input2
,
output
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h
→
paddle/fluid/inference/tensorrt/plugin/transformer_input_
output_
convert_plugin.h
浏览文件 @
e5cf75d8
...
...
@@ -40,14 +40,14 @@ class TransformerInputConvertPlugin : public DynamicPluginTensorRT {
return
"transformer_input_convert_plugin"
;
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
2
;
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
4
;
}
int
initialize
()
TRT_NOEXCEPT
{
return
0
;
}
void
terminate
()
TRT_NOEXCEPT
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
nvinfer1
::
IExprBuilder
&
exprBuilder
)
// NOLINT
TRT_NOEXCEPT
override
;
bool
supportsFormatCombination
(
int
pos
,
...
...
@@ -134,7 +134,120 @@ class TransformerInputConvertPluginCreator : public nvinfer1::IPluginCreator {
std
::
string
plugin_name_
;
nvinfer1
::
PluginFieldCollection
field_collection_
{
0
,
nullptr
};
};
class
TransformerOutputConvertPlugin
:
public
DynamicPluginTensorRT
{
public:
TransformerOutputConvertPlugin
()
{}
TransformerOutputConvertPlugin
(
void
const
*
serial_data
,
size_t
serial_length
)
{}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
{
TransformerOutputConvertPlugin
*
ptr
=
new
TransformerOutputConvertPlugin
();
return
ptr
;
}
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
{
return
"transformer_output_convert_plugin"
;
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
1
;
}
int
initialize
()
TRT_NOEXCEPT
{
return
0
;
}
void
terminate
()
TRT_NOEXCEPT
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
// NOLINT
TRT_NOEXCEPT
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
TRT_NOEXCEPT
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
outputs
,
int
nbOutputs
)
TRT_NOEXCEPT
override
;
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
TRT_NOEXCEPT
override
{
return
0
;
}
void
attachToContext
(
cudnnContext
*
cudnnContext
,
cublasContext
*
cublasContext
,
nvinfer1
::
IGpuAllocator
*
gpuAllocator
)
TRT_NOEXCEPT
override
;
void
detachFromContext
()
TRT_NOEXCEPT
override
;
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
TRT_NOEXCEPT
override
;
void
destroy
()
TRT_NOEXCEPT
override
{
delete
this
;
}
protected:
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
{
return
0
;
}
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
{}
};
class
TransformerOutputConvertPluginCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
TransformerOutputConvertPluginCreator
()
{}
const
char
*
getPluginName
()
const
TRT_NOEXCEPT
override
{
return
"transformer_output_convert_plugin"
;
}
const
char
*
getPluginVersion
()
const
TRT_NOEXCEPT
override
{
return
"1"
;
}
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
TRT_NOEXCEPT
override
{
return
&
field_collection_
;
}
nvinfer1
::
IPluginV2
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
plugin_field
)
TRT_NOEXCEPT
override
{
return
nullptr
;
}
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
void
const
*
serial_data
,
size_t
serial_length
)
TRT_NOEXCEPT
override
{
TransformerOutputConvertPlugin
*
obj
=
new
TransformerOutputConvertPlugin
(
serial_data
,
serial_length
);
obj
->
setPluginNamespace
(
name
);
return
obj
;
}
void
setPluginNamespace
(
const
char
*
lib_namespace
)
TRT_NOEXCEPT
override
{
plugin_namespace_
=
lib_namespace
;
}
const
char
*
getPluginNamespace
()
const
TRT_NOEXCEPT
override
{
return
plugin_namespace_
.
c_str
();
}
private:
std
::
string
plugin_namespace_
;
std
::
string
plugin_name_
;
nvinfer1
::
PluginFieldCollection
field_collection_
{
0
,
nullptr
};
};
REGISTER_TRT_PLUGIN_V2
(
TransformerInputConvertPluginCreator
);
REGISTER_TRT_PLUGIN_V2
(
TransformerOutputConvertPluginCreator
);
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录