Commit f4d9212d (unverified)
Authored by Wilber on Mar 23, 2021; committed via GitHub on Mar 23, 2021
trt plugin upgrade to pluginv2ext (#31670)
Parent: 372ac08a

Showing 10 changed files with 322 additions and 32 deletions (+322, -32)
paddle/fluid/inference/tensorrt/convert/split_op.cc           +1   -1
paddle/fluid/inference/tensorrt/engine.cc                     +8   -1
paddle/fluid/inference/tensorrt/engine.h                      +7   -0
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt         +3   -0
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu     +0   -5
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h      +59  -10
paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc   +58  -0
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc          +65  -13
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h           +110 -2
python/setup.py.in                                            +11  -0
paddle/fluid/inference/tensorrt/convert/split_op.cc

@@ -101,7 +101,7 @@ class SplitOpConverter : public OpConverter {
         engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
     plugin::SplitPlugin* plugin =
         new plugin::SplitPlugin(axis, output_lengths, with_fp16);
-    layer = engine_->AddPlugin(&input, input_num, plugin);
+    layer = engine_->AddPluginV2Ext(&input, input_num, plugin);
   }
   std::string layer_name = "split (Output: ";
paddle/fluid/inference/tensorrt/engine.cc

@@ -18,7 +18,7 @@ limitations under the License. */
 #include <glog/logging.h>
 #include <string>
-#include "cuda_runtime_api.h"
+#include "cuda_runtime_api.h"  // NOLINT
 #include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/gpu_info.h"

@@ -353,6 +353,13 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
   return network()->addPluginExt(inputs, num_inputs, *plugin);
 }

+nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext(
+    nvinfer1::ITensor *const *inputs, int num_inputs,
+    plugin::PluginTensorRTV2Ext *plugin) {
+  owned_plugin_v2ext_.emplace_back(plugin);
+  return network()->addPluginV2(inputs, num_inputs, *plugin);
+}
+
 void TensorRTEngine::freshDeviceId() {
   int count;
   cudaGetDeviceCount(&count);
paddle/fluid/inference/tensorrt/engine.h

@@ -305,8 +305,14 @@ class TensorRTEngine {
   }
   int GetDeviceId() { return device_id_; }
   nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
                                     int num_inputs, plugin::PluginTensorRT*);
+
+  nvinfer1::IPluginV2Layer* AddPluginV2Ext(nvinfer1::ITensor* const* inputs,
+                                           int num_inputs,
+                                           plugin::PluginTensorRTV2Ext* plugin);
+
   void SetTensorDynamicRange(nvinfer1::ITensor* tensor, float range) {
     quant_dynamic_range_[tensor] = range;
   }

@@ -414,6 +420,7 @@ class TensorRTEngine {
       itensor_map_;

   std::vector<std::unique_ptr<plugin::PluginTensorRT>> owned_plugin_;
+  std::vector<std::unique_ptr<plugin::PluginTensorRTV2Ext>> owned_plugin_v2ext_;

   // TensorRT related internal members
   template <typename T>
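
For orientation: AddPluginV2Ext mirrors the existing AddPlugin entry point, but takes a plugin::PluginTensorRTV2Ext, keeps it alive in owned_plugin_v2ext_ (so the engine owns the pointer), and attaches it to the network via addPluginV2, returning an nvinfer1::IPluginV2Layer. A minimal sketch of the intended call pattern from an op converter, mirroring the split_op.cc change above; variable names such as input, input_num, axis, output_lengths, and with_fp16 are illustrative placeholders:

  // Sketch only: the converter allocates the plugin with new and hands
  // ownership to the engine; it must not delete the pointer afterwards.
  plugin::SplitPlugin* split_plugin =
      new plugin::SplitPlugin(axis, output_lengths, with_fp16);
  nvinfer1::IPluginV2Layer* layer =
      engine_->AddPluginV2Ext(&input, input_num, split_plugin);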
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt

@@ -6,3 +6,6 @@ nv_library(tensorrt_plugin
   qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
   hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
   DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
+
+nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
+  paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_plugin)
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu

@@ -22,11 +22,6 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {

-SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) {
-  return new SplitPlugin(buffer, length);
-}
-REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);
-
 template <typename T>
 __device__ int upper_bound(T const* vals, int n, T const& key) {
   int i = 0;
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h

@@ -25,7 +25,7 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {

-class SplitPlugin : public PluginTensorRT {
+class SplitPlugin : public PluginTensorRTV2Ext {
  public:
   SplitPlugin() {}
   SplitPlugin(int axis, std::vector<int> const& output_lengths, bool with_fp16)

@@ -39,13 +39,20 @@ class SplitPlugin : public PluginTensorRT {
     DeserializeValue(&serial_data, &serial_length, &output_length_);
   }

-  SplitPlugin* clone() const override {
-    auto* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
+  nvinfer1::IPluginV2Ext* clone() const override {
+    SplitPlugin* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
+    ptr->setPluginNamespace(this->getPluginNamespace());
     ptr->shareData(this);
     return ptr;
   }

-  const char* getPluginType() const override { return "split_plugin"; }
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType* input_types,
+                                       int nb_inputs) const override {
+    return input_types[0];
+  }
+
+  const char* getPluginType() const override { return "split_plugin_v2ext"; }
   int getNbOutputs() const override { return output_length_.size(); }
   nvinfer1::Dims getOutputDimensions(int index,
                                      const nvinfer1::Dims* input_dims,

@@ -53,17 +60,18 @@ class SplitPlugin : public PluginTensorRT {
   int initialize() override;
   void terminate() override;
-  int enqueue(int batchSize, const void* const* inputs, void** outputs,
+  int enqueue(int batch_size, const void* const* inputs, void** outputs,
               void* workspace, cudaStream_t stream) override;

+  void destroy() override { delete this; }
+
  protected:
-  size_t getSerializationSize() override {
-    return SerializedSize(getPluginType()) + SerializedSize(axis_) +
-           SerializedSize(output_length_) +
+  size_t getSerializationSize() const override {
+    return SerializedSize(axis_) +
+           SerializedSize(output_length_) +
            getBaseSerializationSize();
   }

-  void serialize(void* buffer) override {
-    SerializeValue(&buffer, getPluginType());
+  void serialize(void* buffer) const override {
     serializeBase(buffer);
     SerializeValue(&buffer, axis_);
     SerializeValue(&buffer, output_length_);

@@ -83,6 +91,47 @@ class SplitPlugin : public PluginTensorRT {
   void shareData(const SplitPlugin* another);
 };

+class SplitPluginCreator : public nvinfer1::IPluginCreator {
+ public:
+  SplitPluginCreator() {}
+
+  const char* getPluginName() const override { return "split_plugin_v2ext"; }
+
+  const char* getPluginVersion() const override { return "1"; }
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() override {
+    return &field_collection_;
+  }
+
+  nvinfer1::IPluginV2* createPlugin(
+      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
+    // not implemented
+    return nullptr;
+  }
+
+  nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                         const void* serial_data,
+                                         size_t serial_length) override {
+    auto plugin = new SplitPlugin(serial_data, serial_length);
+    return plugin;
+  }
+
+  void setPluginNamespace(const char* lib_namespace) override {
+    plugin_namespace_ = lib_namespace;
+  }
+
+  const char* getPluginNamespace() const override {
+    return plugin_namespace_.c_str();
+  }
+
+ private:
+  std::string plugin_namespace_;
+  std::string plugin_name_;
+  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
+  std::vector<nvinfer1::PluginField> plugin_attributes_;
+};
+
+REGISTER_TRT_PLUGIN_V2(SplitPluginCreator);
+
 #if IS_TRT_VERSION_GE(6000)
 class SplitPluginDynamic : public DynamicPluginTensorRT {
  public:
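
Note on the new creator: REGISTER_TRT_PLUGIN_V2(SplitPluginCreator) places the creator in TensorRT's global plugin registry under the name "split_plugin_v2ext", version "1", which is what lets an engine containing this plugin be deserialized without the old Paddle-side factory function removed from split_op_plugin.cu. A rough sketch of that lookup path, assuming the standard nvinfer1::IPluginRegistry API; serial_data and serial_len stand for bytes previously produced by SplitPlugin::serialize() and are not defined here:

  // Sketch only: fetch the registered creator and rebuild a plugin instance.
  auto* creator = getPluginRegistry()->getPluginCreator("split_plugin_v2ext", "1");
  if (creator != nullptr) {
    nvinfer1::IPluginV2* plugin =
        creator->deserializePlugin("split_plugin_v2ext", serial_data, serial_len);
    // ... hand the plugin to a network or engine, then release it when done:
    plugin->destroy();
  }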
paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc  (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

TEST(split_op_plugin, test_plugin) {
  int axis = 1;
  std::vector<int> output_lengths{1, 1};
  bool with_fp16 = false;
  std::vector<nvinfer1::DataType> input_types{nvinfer1::DataType::kFLOAT};
  std::vector<nvinfer1::Dims> input_dims;

  SplitPlugin sp_plugin(axis, output_lengths, with_fp16);
  nvinfer1::Dims in_dims;
  in_dims.nbDims = 4;
  input_dims.push_back(in_dims);
  sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2,
                            input_types.data(), nullptr, nullptr, nullptr,
                            nvinfer1::PluginFormat::kNCHW, 4);
  sp_plugin.initialize();
  sp_plugin.getPluginType();
  sp_plugin.canBroadcastInputAcrossBatch(0);
  sp_plugin.getNbOutputs();
  auto clone_plugin = sp_plugin.clone();
  clone_plugin->setPluginNamespace("test");
  clone_plugin->destroy();
  sp_plugin.getOutputDataType(0, input_types.data(), 1);
  sp_plugin.terminate();
}

TEST(split_op_plugin, test_plugin_creater) {
  SplitPluginCreator creator;
  creator.getFieldNames();
  creator.createPlugin("test", nullptr);
  creator.setPluginNamespace("test");
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc

@@ -19,27 +19,50 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {

+inline void Seria(void*& buffer,  // NOLINT
+                  const std::vector<nvinfer1::Dims>& input_dims,
+                  size_t max_batch_size, nvinfer1::DataType data_type,
+                  nvinfer1::PluginFormat data_format, bool with_fp16) {
+  SerializeValue(&buffer, input_dims);
+  SerializeValue(&buffer, max_batch_size);
+  SerializeValue(&buffer, data_type);
+  SerializeValue(&buffer, data_format);
+  SerializeValue(&buffer, with_fp16);
+}
+
+inline void Deseria(void const*& serial_data, size_t& serial_length,  // NOLINT
+                    std::vector<nvinfer1::Dims>* input_dims,
+                    size_t* max_batch_size, nvinfer1::DataType* data_type,
+                    nvinfer1::PluginFormat* data_format, bool* with_fp16) {
+  DeserializeValue(&serial_data, &serial_length, input_dims);
+  DeserializeValue(&serial_data, &serial_length, max_batch_size);
+  DeserializeValue(&serial_data, &serial_length, data_type);
+  DeserializeValue(&serial_data, &serial_length, data_format);
+  DeserializeValue(&serial_data, &serial_length, with_fp16);
+}
+
+inline size_t SeriaSize(const std::vector<nvinfer1::Dims>& input_dims,
+                        size_t max_batch_size, nvinfer1::DataType data_type,
+                        nvinfer1::PluginFormat data_format, bool with_fp16) {
+  return (SerializedSize(input_dims) + SerializedSize(max_batch_size) +
+          SerializedSize(data_type) + SerializedSize(data_format) +
+          SerializedSize(with_fp16));
+}
+
 void PluginTensorRT::serializeBase(void*& buffer) {
-  SerializeValue(&buffer, input_dims_);
-  SerializeValue(&buffer, max_batch_size_);
-  SerializeValue(&buffer, data_type_);
-  SerializeValue(&buffer, data_format_);
-  SerializeValue(&buffer, with_fp16_);
+  Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_,
+        with_fp16_);
 }

 void PluginTensorRT::deserializeBase(void const*& serial_data,
                                      size_t& serial_length) {
-  DeserializeValue(&serial_data, &serial_length, &input_dims_);
-  DeserializeValue(&serial_data, &serial_length, &max_batch_size_);
-  DeserializeValue(&serial_data, &serial_length, &data_type_);
-  DeserializeValue(&serial_data, &serial_length, &data_format_);
-  DeserializeValue(&serial_data, &serial_length, &with_fp16_);
+  Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_,
+          &data_type_, &data_format_, &with_fp16_);
 }

 size_t PluginTensorRT::getBaseSerializationSize() {
-  return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
-          SerializedSize(data_type_) + SerializedSize(data_format_) +
-          SerializedSize(with_fp16_));
+  return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_,
+                   with_fp16_);
 }

 bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,

@@ -58,6 +81,35 @@ void PluginTensorRT::configureWithFormat(
   max_batch_size_ = max_batch_size;
 }

+void PluginTensorRTV2Ext::serializeBase(void*& buffer) const {
+  Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_,
+        with_fp16_);
+}
+
+void PluginTensorRTV2Ext::deserializeBase(void const*& serial_data,
+                                          size_t& serial_length) {
+  Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_,
+          &data_type_, &data_format_, &with_fp16_);
+}
+
+size_t PluginTensorRTV2Ext::getBaseSerializationSize() const {
+  return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_,
+                   with_fp16_);
+}
+
+void PluginTensorRTV2Ext::configurePlugin(
+    const nvinfer1::Dims* input_dims, int32_t nb_inputs,
+    const nvinfer1::Dims* output_dims, int32_t nb_outputs,
+    const nvinfer1::DataType* input_types,
+    const nvinfer1::DataType* output_types, const bool* input_is_broadcast,
+    const bool* output_is_broadcast, nvinfer1::PluginFormat float_format,
+    int32_t max_batch_size) {
+  input_dims_.assign(input_dims, input_dims + nb_inputs);
+  max_batch_size_ = max_batch_size;
+  data_format_ = float_format;
+  data_type_ = input_types[0];
+}
+
 }  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
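
The three helpers pin down a single on-buffer layout for the base fields (input_dims_, max_batch_size_, data_type_, data_format_, with_fp16_) that PluginTensorRT and the new PluginTensorRTV2Ext now share. A hedged sketch of how a derived plugin is expected to line up with that layout; MyPlugin and its scale_ member are hypothetical and not part of this commit:

  // serialize() must write exactly what the deserializing constructor reads,
  // in the same order: base fields first, plugin-specific fields after.
  void MyPlugin::serialize(void* buffer) const {
    serializeBase(buffer);            // dims, max batch size, type, format, fp16 flag
    SerializeValue(&buffer, scale_);  // hypothetical extra field
  }

  MyPlugin::MyPlugin(void const* serial_data, size_t serial_length) {
    deserializeBase(serial_data, serial_length);              // same base fields
    DeserializeValue(&serial_data, &serial_length, &scale_);  // same extras, same order
  }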
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h

@@ -44,6 +44,7 @@ typedef std::function<PluginTensorRT*(const void*, size_t)>
 typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;

+// Deprecated. Do not inherit this class, please refer to PluginTensorRTV2Ext
 class PluginTensorRT : public nvinfer1::IPluginExt {
  public:
   PluginTensorRT() : with_fp16_(false) {}

@@ -119,6 +120,114 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
   bool with_fp16_;
 };

+// TensorRT introduced IPluginV2Ext after 5.1, Paddle no longer supports
+// versions before 5.1
+class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext {
+ public:
+  PluginTensorRTV2Ext() : with_fp16_(false) {}
+  PluginTensorRTV2Ext(const void* serialized_data, size_t length) {}
+
+  nvinfer1::Dims const& getInputDims(int index) const {
+    return input_dims_.at(index);
+  }
+  size_t getMaxBatchSize() const { return max_batch_size_; }
+  nvinfer1::DataType getDataType() const { return data_type_; }
+  nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
+
+  // The Func in IPluginV2Ext
+  virtual nvinfer1::DataType getOutputDataType(
+      int index, const nvinfer1::DataType* input_types,
+      int nb_inputs) const = 0;
+
+  virtual bool isOutputBroadcastAcrossBatch(int32_t output_index,
+                                            const bool* input_is_broadcasted,
+                                            int32_t nb_inputs) const {
+    return false;
+  }
+
+  virtual bool canBroadcastInputAcrossBatch(int32_t input_index) const {
+    return false;
+  }
+
+  void configurePlugin(const nvinfer1::Dims* input_dims, int32_t nb_inputs,
+                       const nvinfer1::Dims* output_dims, int32_t nb_outputs,
+                       const nvinfer1::DataType* input_types,
+                       const nvinfer1::DataType* output_types,
+                       const bool* input_is_broadcast,
+                       const bool* output_is_broadcast,
+                       nvinfer1::PluginFormat float_format,
+                       int32_t max_batch_size) override;
+
+  virtual IPluginV2Ext* clone() const = 0;
+
+  void attachToContext(cudnnContext*, cublasContext*,
+                       nvinfer1::IGpuAllocator*) override {}
+
+  void detachFromContext() override {}
+
+  // The Func in IPluginV2
+  virtual const char* getPluginType() const = 0;
+  const char* getPluginVersion() const override { return "1"; }
+  virtual int32_t getNbOutputs() const { return 1; }
+  virtual nvinfer1::Dims getOutputDimensions(int32_t index,
+                                             const nvinfer1::Dims* inputs,
+                                             int32_t nb_input) = 0;
+  // Check format support. The default is FLOAT32 and NCHW.
+  bool supportsFormat(nvinfer1::DataType type,
+                      nvinfer1::PluginFormat format) const override {
+    return ((type == nvinfer1::DataType::kFLOAT) &&
+            (format == nvinfer1::PluginFormat::kNCHW));
+  }
+  // Initialize the layer for execution.
+  // This is called when the engine is created.
+  int initialize() override { return 0; }
+
+  // Shutdown the layer. This is called when the engine is destroyed
+  void terminate() override {}
+
+  // Find the workspace size required by the layer
+  size_t getWorkspaceSize(int) const override { return 0; }
+
+  // Execute the layer
+  virtual int enqueue(int batch_size, const void* const* inputs,
+                      void** outputs, void* workspace,
+                      cudaStream_t stream) = 0;
+
+  // Find the size of the serialization buffer required
+  virtual size_t getSerializationSize() const = 0;
+
+  // Serialize the layer config to buffer.
+  // TensorRT will call this func to serialize the configuration of TensorRT
+  // engine. It should not be called by users.
+  virtual void serialize(void* buffer) const = 0;
+
+  virtual void destroy() = 0;
+
+  void setPluginNamespace(const char* plugin_namespace) override {
+    name_space_ = plugin_namespace;
+  }
+
+  const char* getPluginNamespace() const override {
+    return name_space_.c_str();
+  }
+
+ protected:
+  void deserializeBase(void const*& serial_data,  // NOLINT
+                       size_t& serial_length);    // NOLINT
+  size_t getBaseSerializationSize() const;
+  void serializeBase(void*& buffer) const;  // NOLINT
+
+ protected:
+  std::vector<nvinfer1::Dims> input_dims_;
+  size_t max_batch_size_;
+  nvinfer1::DataType data_type_;
+  nvinfer1::PluginFormat data_format_;
+  std::vector<nvinfer1::ITensor*> inputs_;
+  bool with_fp16_;
+
+ private:
+  std::string name_space_;
+};
+
 #if IS_TRT_VERSION_GE(6000)
 class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
  public:

@@ -184,6 +293,7 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
   std::string name_space_;
   std::string plugin_base_;
 };
+#endif

 template <typename T>
 class TrtPluginRegistrarV2 {

@@ -203,8 +313,6 @@ class TrtPluginRegistrarV2 {
   static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
       plugin_registrar_##name {}

-#endif
-
 }  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
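
To make the required surface of the new base class concrete, here is a minimal, hypothetical subclass sketch (not part of this commit) that overrides only the members PluginTensorRTV2Ext leaves pure virtual; everything else (supportsFormat, initialize, terminate, getWorkspaceSize, configurePlugin) falls back to the defaults above. IdentityPlugin and its pass-through behavior are made up for illustration and assume trt_plugin.h is included:

  class IdentityPlugin : public PluginTensorRTV2Ext {
   public:
    IdentityPlugin() {}
    IdentityPlugin(void const* serial_data, size_t serial_length) {
      deserializeBase(serial_data, serial_length);
    }

    const char* getPluginType() const override { return "identity_plugin_v2ext"; }

    nvinfer1::DataType getOutputDataType(int index,
                                         const nvinfer1::DataType* input_types,
                                         int nb_inputs) const override {
      return input_types[0];  // same type as the input
    }

    nvinfer1::Dims getOutputDimensions(int32_t index,
                                       const nvinfer1::Dims* inputs,
                                       int32_t nb_input) override {
      return inputs[0];  // same shape as the input
    }

    nvinfer1::IPluginV2Ext* clone() const override {
      // A real plugin would copy its configuration into the clone.
      return new IdentityPlugin();
    }

    int enqueue(int batch_size, const void* const* inputs, void** outputs,
                void* workspace, cudaStream_t stream) override {
      // A real plugin launches its CUDA kernels here; 0 signals success.
      return 0;
    }

    size_t getSerializationSize() const override {
      return getBaseSerializationSize();
    }
    void serialize(void* buffer) const override { serializeBase(buffer); }

    void destroy() override { delete this; }
  };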
python/setup.py.in

@@ -336,6 +336,17 @@ if '${WITH_XPU_BKCL}' == 'ON':
     shutil.copy('${XPU_BKCL_LIB}', libs_path)
     package_data['paddle.libs']+=['${XPU_BKCL_LIB_NAME}']

+# Only for lite xpu inference.
+if '${WITH_XPU}' == 'OFF' and '${XPU_SDK_ROOT}' != '':
+    xpu_api_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/shlib/', 'libxpuapi.so')
+    xpu_rt_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/runtime/shlib/', 'libxpurt.so')
+    if os.path.exists(xpu_api_lib):
+        shutil.copy(xpu_api_lib, libs_path)
+        package_data['paddle.libs']+=['libxpuapi.so']
+    if os.path.exists(xpu_rt_lib):
+        shutil.copy(xpu_rt_lib, libs_path)
+        package_data['paddle.libs']+=['libxpurt.so']
+
 ### Old custom op extension mechanism related, will be removed in 2.1.0 ###
 # copy libpaddle_framework.so to libs on linux
 if sys.platform.startswith('linux'):