PaddlePaddle / Paddle · commit 2bdad6cd (unverified)
Authored by Zhang Jun on Dec 01, 2022; committed via GitHub on Dec 01, 2022.
[inference][trt] Fp16 support for Generic plugin (#48253)
* Support FP16 in generic TensorRT plugin.
* Support FP16 for Pad3D.
Parent: 9ffc760f
Showing 6 changed files with 127 additions and 67 deletions (+127, -67).
- paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc (+1, -12)
- paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc (+4, -1)
- paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc (+1, -1)
- paddle/fluid/inference/tensorrt/helper.h (+10, -0)
- paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu (+94, -47)
- paddle/fluid/inference/tensorrt/plugin/generic_plugin.h (+17, -6)
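For context before the per-file diffs, here is a minimal sketch (not part of this commit) of how the FP16 path is reached from the Paddle Inference C++ API: building the TensorRT engine with `PrecisionType::kHalf` is what makes `engine_->WithFp16()` true, and the last argument of `SetTRTDynamicShapeInfo` is what `disable_trt_plugin_fp16()` reports. The model file names and the tensor name "x" are placeholders.

```cpp
#include "paddle_inference_api.h"

// Sketch: configure Paddle Inference so TensorRT subgraphs (and, after this
// commit, generic plugins such as pad3d) may run in FP16.
paddle_infer::Config MakeTrtFp16Config() {
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  config.EnableUseGpu(256 /*initial GPU pool, MB*/, 0 /*device id*/);
  config.EnableTensorRtEngine(1 << 30 /*workspace bytes*/,
                              1 /*max batch*/,
                              3 /*min subgraph size*/,
                              paddle_infer::PrecisionType::kHalf,
                              false /*use_static*/,
                              false /*use_calib_mode*/);
  // Passing true as the last argument would keep plugins in FP32 even inside
  // an FP16 engine; it is the source of disable_trt_plugin_fp16().
  config.SetTRTDynamicShapeInfo({{"x", {1, 3, 224, 224}}},
                                {{"x", {1, 3, 224, 224}}},
                                {{"x", {1, 3, 224, 224}}},
                                false /*disable_trt_plugin_fp16*/);
  return config;
}
```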
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc

```diff
@@ -30,9 +30,7 @@
 #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/inference/tensorrt/op_teller.h"
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
 #include "paddle/fluid/inference/utils/io_utils.h"
 #include "paddle/phi/common/backend.h"
 #include "paddle/phi/common/data_type.h"
@@ -42,15 +40,6 @@ namespace inference {
 namespace analysis {
 namespace {
-bool IsFloat(framework::proto::VarType::Type t) {
-  if (t == framework::proto::VarType::FP16 ||
-      t == framework::proto::VarType::FP32 ||
-      t == framework::proto::VarType::FP64 ||
-      t == framework::proto::VarType::BF16)
-    return true;
-  return false;
-}
-
 // if in mixed model precision, we should make all tensorrt_engine's output
 // floats dtype to float32 dtype.
 void OutputProcess(framework::ir::Graph *graph,
@@ -85,7 +74,7 @@ void OutputProcess(framework::ir::Graph *graph,
   for (auto *var_node : op_node->outputs) {
     if (!trt_outputs.count(var_node)) continue;
     if (!var_node->Var()->Persistable() &&
-        IsFloat(var_node->Var()->GetDataType()) &&
+        tensorrt::IsFloatVar(var_node->Var()->GetDataType()) &&
         var_node->Var()->GetDataType() != framework::proto::VarType::FP32) {
       for (auto *next_op : var_node->outputs) {
         // if next_op support mixed_precision, we need to add cast op.
```
paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc

```diff
@@ -182,6 +182,8 @@ class GenericPluginCreater : public OpConverter {
           phi::DefaultKernelSignatureMap::Instance().Get(op_desc.Type());
     }
+    bool with_fp16 =
+        engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
     plugin::GenericPlugin::InputOutPutVarInfo in_out_info;
     for (auto &param_name : phi_kernel_signature.input_names) {
@@ -218,7 +220,8 @@ class GenericPluginCreater : public OpConverter {
         in_out_info.outputs_data_type.push_back(var->GetDataType());
       }
     }
-    plugin::GenericPlugin* plugin = new plugin::GenericPlugin(op, in_out_info);
+    plugin::GenericPlugin* plugin =
+        new plugin::GenericPlugin(op, in_out_info, with_fp16);
     layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);
     RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode);
```
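The two engine queries above collapse into a single opt-in flag for the plugin. A standalone sketch of the same gating rule, using a hypothetical `Engine` stand-in rather than Paddle's `TensorRTEngine`:

```cpp
// Hypothetical stand-in: a plugin uses FP16 only if the engine was built for
// FP16 and plugin FP16 was not explicitly disabled by the user.
struct Engine {
  bool with_fp16 = false;
  bool disable_trt_plugin_fp16 = false;
};

inline bool PluginShouldUseFp16(const Engine& engine) {
  return engine.with_fp16 && !engine.disable_trt_plugin_fp16;
}
```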
paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc

```diff
@@ -60,7 +60,7 @@ class MultiheadMatMulRoformerOpConverter : public OpConverter {
              weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float));
       // (hidden_in, 3, hidden_out)
-      auto weight_dims = weight_t->dims();
+      auto &weight_dims = weight_t->dims();
       int hidden_in = weight_dims[0];   // channels_in
       int three = weight_dims[1];       // channels_out
```
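The one-line change above only swaps a deduced copy for a reference binding. A self-contained illustration (hypothetical `Tensor` type, not Paddle's `DenseTensor`) of why `auto&` avoids copying the returned dimension object:

```cpp
#include <cstdint>
#include <vector>

// dims() returns a reference to internal storage, mirroring the real API.
struct Tensor {
  std::vector<int64_t> dims_;
  const std::vector<int64_t>& dims() const { return dims_; }
};

int main() {
  Tensor weight{{768, 3, 768}};
  auto copied = weight.dims();         // deduces a value: copies the dims
  const auto& viewed = weight.dims();  // deduces a reference: no copy
  return copied == viewed ? 0 : 1;     // same contents either way
}
```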
paddle/fluid/inference/tensorrt/helper.h

```diff
@@ -22,6 +22,7 @@
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/platform/dynload/tensorrt.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/phi/common/data_type.h"
@@ -213,6 +214,15 @@ static inline nvinfer1::DataType PhiType2NvType(phi::DataType type) {
   }
   return nv_type;
 }
+
+static bool IsFloatVar(framework::proto::VarType::Type t) {
+  if (t == framework::proto::VarType::FP16 ||
+      t == framework::proto::VarType::FP32 ||
+      t == framework::proto::VarType::FP64 ||
+      t == framework::proto::VarType::BF16)
+    return true;
+  return false;
+}
 
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
```
paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu

```diff
@@ -30,8 +30,11 @@ namespace plugin {
 void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
                                phi::KernelContext* kernel_context,
                                const phi::KernelSignature& signature,
-                               const phi::Kernel& phi_kernel) {
-  const phi::KernelArgsDef& args_def = phi_kernel.args_def();
+                               const phi::Kernel* phi_kernel) {
+  if (!phi_kernel->IsValid()) {
+    return;
+  }
+  const phi::KernelArgsDef& args_def = phi_kernel->args_def();
   const auto& attr_names = signature.attr_names;
   const auto& attr_defs = args_def.attribute_defs();
@@ -221,28 +224,34 @@ void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
 GenericPlugin::GenericPlugin(
     const paddle::framework::proto::OpDesc& proto_op_desc,
-    const InputOutPutVarInfo& in_out_info) {
+    const InputOutPutVarInfo& in_out_info,
+    bool with_fp16) {
   proto_op_desc_ = proto_op_desc;
   op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr));
   proto_op_desc_.SerializeToString(&op_meta_data_);
   inputs_data_type_ = in_out_info.inputs_data_type;
   outputs_data_type_ = in_out_info.outputs_data_type;
+  with_fp16_ = with_fp16;
 }
 
 GenericPlugin::GenericPlugin(
     const paddle::framework::proto::OpDesc& proto_op_desc,
     const std::vector<int>& inputs_data_type,
-    const std::vector<int>& outputs_data_type) {
+    const std::vector<int>& outputs_data_type,
+    bool with_fp16) {
   proto_op_desc_ = proto_op_desc;
   op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr));
   proto_op_desc_.SerializeToString(&op_meta_data_);
   inputs_data_type_ = inputs_data_type;
   outputs_data_type_ = outputs_data_type;
+  with_fp16_ = with_fp16;
 }
 
 GenericPlugin::GenericPlugin(void const* serial_data, size_t serial_length) {
   DeserializeValue(&serial_data, &serial_length, &inputs_data_type_);
   DeserializeValue(&serial_data, &serial_length, &outputs_data_type_);
+  DeserializeValue(&serial_data, &serial_length, &with_fp16_);
   std::string op_meta_data((char*)(serial_data), serial_length);  // NOLINT
   op_meta_data_ = std::move(op_meta_data);
   proto_op_desc_.ParseFromString(op_meta_data_);
@@ -266,8 +275,8 @@ int GenericPlugin::getNbInputs() const TRT_NOEXCEPT {
 }
 
 nvinfer1::IPluginV2DynamicExt* GenericPlugin::clone() const TRT_NOEXCEPT {
-  nvinfer1::IPluginV2DynamicExt* plugin =
-      new GenericPlugin(proto_op_desc_, inputs_data_type_, outputs_data_type_);
+  nvinfer1::IPluginV2DynamicExt* plugin = new GenericPlugin(
+      proto_op_desc_, inputs_data_type_, outputs_data_type_, with_fp16_);
   plugin->initialize();
   return plugin;
 }
@@ -277,6 +286,8 @@ void GenericPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
   SerializeValue(&buffer, inputs_data_type_);
   // outputs_data_type_
   SerializeValue(&buffer, outputs_data_type_);
+  // use fp16
+  SerializeValue(&buffer, with_fp16_);
   // serialize op_meta_data_
   std::memcpy(buffer, op_meta_data_.c_str(), op_meta_data_.size());
   reinterpret_cast<char*&>(buffer) += op_meta_data_.size();
@@ -310,6 +321,12 @@ bool GenericPlugin::supportsFormatCombination(
     if (pos == 3)
       return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
              (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
+  } else if (op_desc_.Type() == "pad3d") {
+    return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
+            (isFp16Supported() &&
+             in_out[pos].type == nvinfer1::DataType::kHALF)) &&
+           (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR) &&
+           (in_out[0].type == in_out[pos].type);
   } else {
     return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
            (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
@@ -337,34 +354,43 @@ int GenericPlugin::initialize() TRT_NOEXCEPT {
         phi::DefaultKernelSignatureMap::Instance().Get(op_type);
   }
 
-  phi::KernelKey phi_kernel_key(
-      phi::Backend::GPU, phi::DataLayout::ANY, phi::DataType::FLOAT32);
   PADDLE_ENFORCE_EQ(
       phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type),
       true,
       platform::errors::Fatal("%s has no compatible phi kernel!",
                               op_type.c_str()));
-  const phi::Kernel& phi_kernel = phi::KernelFactory::Instance().SelectKernel(
-      phi_kernel_signature.name, phi_kernel_key);
-  phi_kernel_ = &phi_kernel;
-  PADDLE_ENFORCE_EQ(phi_kernel_->IsValid(),
-                    true,
-                    platform::errors::Fatal("%s phi kernel is invalid!.",
-                                            phi_kernel_signature.name));
 
   paddle::platform::DeviceContextPool& pool =
       paddle::platform::DeviceContextPool::Instance();
   platform::CUDAPlace place(platform::GetCurrentDeviceId());
   auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(place));
 
-  if (!phi_kernel_context_) {
-    phi_kernel_context_ = new phi::KernelContext(dev_ctx);
-    BuildPhiKernelContextAttr(
-        op_desc_, phi_kernel_context_, phi_kernel_signature, phi_kernel);
-  }
+  std::vector<phi::DataType> precision_types{phi::DataType::FLOAT32,
+                                             phi::DataType::FLOAT16};
+  for (auto& precision_type : precision_types) {
+    phi::KernelKey phi_kernel_key(
+        phi::Backend::GPU, phi::DataLayout::ANY, precision_type);
+    auto nv_dtype = PhiType2NvType(precision_type);
+    phi_kernels_[nv_dtype].reset(
+        new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
+            phi_kernel_signature.name, phi_kernel_key)));
+    if (phi_kernel_contexts_.find(nv_dtype) == phi_kernel_contexts_.end() ||
+        !phi_kernel_contexts_[nv_dtype]) {
+      phi_kernel_contexts_[nv_dtype].reset(new phi::KernelContext(dev_ctx));
+      BuildPhiKernelContextAttr(op_desc_,
+                                phi_kernel_contexts_[nv_dtype].get(),
+                                phi_kernel_signature,
+                                phi_kernels_[nv_dtype].get());
+    }
+  }
+  PADDLE_ENFORCE_EQ(phi_kernels_[nvinfer1::DataType::kFLOAT]->IsValid() ||
+                        phi_kernels_[nvinfer1::DataType::kHALF]->IsValid(),
+                    true,
+                    platform::errors::Fatal("%s phi kernel is invalid!.",
+                                            phi_kernel_signature.name));
 
   if (!dense_tensor_inputs_)
     dense_tensor_inputs_ = new std::vector<phi::DenseTensor>(getNbInputs());
   if (!dense_tensor_outputs_)
@@ -396,15 +422,14 @@ void GenericPlugin::configurePlugin(
     int nb_inputs,
     const nvinfer1::DynamicPluginTensorDesc* out,
     int nb_outputs) TRT_NOEXCEPT {
-  CHECK(phi_kernel_context_);
-  CHECK(phi_kernel_);
+  CHECK(phi_kernels_[nvinfer1::DataType::kFLOAT]->IsValid() ||
+        phi_kernels_[nvinfer1::DataType::kHALF]->IsValid());
   CHECK(nb_inputs == getNbInputs());
   CHECK(nb_outputs == getNbOutputs());
 }
 
 // Shutdown the layer. This is called when the engine is destroyed
 void GenericPlugin::terminate() TRT_NOEXCEPT {
-  delete phi_kernel_context_;
   delete dense_tensor_inputs_;
   delete dense_tensor_outputs_;
 }
@@ -418,27 +443,42 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
   platform::CUDAPlace place(platform::GetCurrentDeviceId());
 
-  // [TODO]now generic plugin do not support FP16 and INT8 precision
-  auto protoType2PhiType = [](int proto_type) -> std::pair<phi::DataType, int> {
+  auto protoType2PhiType =
+      [&](int proto_type,
+          nvinfer1::DataType nv_dtype) -> std::pair<phi::DataType, int> {
     if (proto_type ==
-        static_cast<int>(framework::proto::VarType_Type::VarType_Type_FP32))
-      return {phi::DataType::FLOAT32, sizeof(float)};
-    else if (proto_type ==
-             static_cast<int>(framework::proto::VarType_Type::VarType_Type_INT64) ||
-             proto_type ==
-                 static_cast<int>(framework::proto::VarType_Type::VarType_Type_INT32))
-      return {phi::DataType::INT32, sizeof(int32_t)};
-    else if (proto_type ==
-             static_cast<int>(framework::proto::VarType_Type::VarType_Type_BOOL))
-      return {phi::DataType::BOOL, sizeof(bool)};
-    else
-      CHECK(false) << "precision is not supported";
+        static_cast<int>(framework::proto::VarType_Type::VarType_Type_FP16)) {
+      return {phi::DataType::FLOAT16, sizeof(half)};
+    } else if (proto_type ==
+               static_cast<int>(framework::proto::VarType_Type::VarType_Type_FP32)) {
+      if (isFp16Supported() && nv_dtype == nvinfer1::DataType::kHALF) {
+        return {phi::DataType::FLOAT16, sizeof(half)};
+      } else {
+        return {phi::DataType::FLOAT32, sizeof(float)};
+      }
+    } else if (proto_type ==
+               static_cast<int>(framework::proto::VarType_Type::VarType_Type_INT64) ||
+               proto_type ==
+                   static_cast<int>(framework::proto::VarType_Type::VarType_Type_INT32)) {
+      return {phi::DataType::INT32, sizeof(int32_t)};
+    } else if (proto_type ==
+               static_cast<int>(framework::proto::VarType_Type::VarType_Type_BOOL)) {
+      return {phi::DataType::BOOL, sizeof(bool)};
+    } else {
+      CHECK(false) << "precision is not supported";
+    }
   };
 
   // input
-  phi_kernel_context_->ClearInputOutput();
+  auto data_type = input_desc[0].type;
+  CHECK((data_type == nvinfer1::DataType::kFLOAT) ||
+        (data_type == nvinfer1::DataType::kHALF));
+  phi_kernel_contexts_[data_type]->ClearInputOutput();
   for (int i = 0; i < getNbInputs(); i++) {
     auto const& input_dims = input_desc[i].dims;
@@ -450,7 +490,9 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
     int input_numel = 1;
     for (int k = 0; k < input_shape.size(); k++) input_numel *= input_shape[k];
-    auto data_type_and_size = protoType2PhiType(inputs_data_type_[i]);
+    auto data_type_and_size =
+        protoType2PhiType(inputs_data_type_[i], data_type);
     phi::DenseTensorMeta input_meta(data_type_and_size.first,
                                     phi::make_ddim(input_shape));
     std::shared_ptr<phi::Allocation> input_alloc(
@@ -459,9 +501,9 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
         place));
     (*dense_tensor_inputs_)[i] =
         std::move(phi::DenseTensor(input_alloc, input_meta));
-    phi_kernel_context_->EmplaceBackInput(&((*dense_tensor_inputs_)[i]));
+    phi_kernel_contexts_[data_type]->EmplaceBackInput(
+        &((*dense_tensor_inputs_)[i]));
   }
 
   // output
   for (int i = 0; i < getNbOutputs(); i++) {
     auto const& output_dims = output_desc[i].dims;
@@ -474,23 +516,28 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
     for (int k = 0; k < output_shape.size(); k++)
       output_numel *= output_shape[k];
-    auto data_type_and_size = protoType2PhiType(inputs_data_type_[i]);
+    auto data_type_and_size =
+        protoType2PhiType(inputs_data_type_[i], data_type);
     phi::DenseTensorMeta output_meta(data_type_and_size.first,
                                      phi::make_ddim(output_shape));
     std::shared_ptr<phi::Allocation> output_alloc(
         new phi::Allocation(reinterpret_cast<void*>(outputs[i]),
                             output_numel * data_type_and_size.second,
                             place));
     phi::DenseTensor output_densetonsor(output_alloc, output_meta);
     (*dense_tensor_outputs_)[i] =
         std::move(phi::DenseTensor(output_alloc, output_meta));
-    phi_kernel_context_->EmplaceBackOutput(&((*dense_tensor_outputs_)[i]));
+    phi_kernel_contexts_[data_type]->EmplaceBackOutput(
+        &((*dense_tensor_outputs_)[i]));
   }
 
-  CHECK_EQ(phi_kernel_context_->InputsSize(), getNbInputs());
-  CHECK_EQ(phi_kernel_context_->OutputsSize(), getNbOutputs());
+  CHECK_EQ(phi_kernel_contexts_[data_type]->InputsSize(), getNbInputs());
+  CHECK_EQ(phi_kernel_contexts_[data_type]->OutputsSize(), getNbOutputs());
 
-  (*phi_kernel_)(phi_kernel_context_);
+  (*phi_kernels_[data_type])(phi_kernel_contexts_[data_type].get());
 
   return cudaGetLastError() != cudaSuccess;
 }
```
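The core of the .cu change is that `initialize()` now prepares one phi kernel and one kernel context per TensorRT precision, and `enqueue()` dispatches on the runtime type of input 0. A minimal standalone sketch of that pattern, with hypothetical types in place of Paddle's `phi::Kernel` and `phi::KernelContext`:

```cpp
#include <cassert>
#include <cstdio>
#include <functional>
#include <map>
#include <memory>

enum class TrtDType { kFLOAT, kHALF };

struct KernelContext {};  // would hold the bound inputs/outputs
using Kernel = std::function<void(KernelContext*)>;

struct GenericPluginSketch {
  std::map<TrtDType, std::unique_ptr<Kernel>> kernels_;
  std::map<TrtDType, std::unique_ptr<KernelContext>> contexts_;

  // initialize(): select a kernel and build a context for every precision.
  void Initialize() {
    for (TrtDType t : {TrtDType::kFLOAT, TrtDType::kHALF}) {
      kernels_[t] = std::make_unique<Kernel>([t](KernelContext*) {
        std::printf("running %s kernel\n",
                    t == TrtDType::kHALF ? "fp16" : "fp32");
      });
      contexts_[t] = std::make_unique<KernelContext>();
    }
  }

  // enqueue(): pick the kernel/context pair matching the runtime dtype.
  int Enqueue(TrtDType input_type) {
    assert(kernels_.count(input_type) && contexts_.count(input_type));
    (*kernels_[input_type])(contexts_[input_type].get());
    return 0;
  }
};

int main() {
  GenericPluginSketch plugin;
  plugin.Initialize();
  plugin.Enqueue(TrtDType::kFLOAT);  // fp32 path
  plugin.Enqueue(TrtDType::kHALF);   // fp16 path
}
```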
paddle/fluid/inference/tensorrt/plugin/generic_plugin.h

```diff
@@ -44,7 +44,7 @@ namespace plugin {
 void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
                                phi::KernelContext* kernel_context,
                                const phi::KernelSignature& signature,
-                               const phi::Kernel& phi_kernel);
+                               const phi::Kernel* phi_kernel);
 
 class GenericPlugin : public DynamicPluginTensorRT {
  public:
@@ -57,11 +57,13 @@ class GenericPlugin : public DynamicPluginTensorRT {
   GenericPlugin() {}
 
   GenericPlugin(const paddle::framework::proto::OpDesc& proto_op_desc,
-                const InputOutPutVarInfo& in_out_info);
+                const InputOutPutVarInfo& in_out_info,
+                bool with_fp16_ = false);
 
   GenericPlugin(const paddle::framework::proto::OpDesc& proto_op_desc,
                 const std::vector<int>& inputs_data_type,
-                const std::vector<int>& outputs_data_type);
+                const std::vector<int>& outputs_data_type,
+                bool with_fp16_ = false);
 
   // It was used for tensorrt deserialization.
   // It should not be called by users.
@@ -86,7 +88,7 @@ class GenericPlugin : public DynamicPluginTensorRT {
   size_t getSerializationSize() const TRT_NOEXCEPT {
     return op_meta_data_.size() + SerializedSize(inputs_data_type_) +
-           SerializedSize(outputs_data_type_);
+           SerializedSize(outputs_data_type_) + SerializedSize(with_fp16_);
   }
 
   void serialize(void* buffer) const TRT_NOEXCEPT;
@@ -122,15 +124,24 @@ class GenericPlugin : public DynamicPluginTensorRT {
                                        const nvinfer1::DataType* input_types,
                                        int nb_inputs) const TRT_NOEXCEPT;
 
+  bool isFp16Supported() {
+    auto half_dtype = nvinfer1::DataType::kHALF;
+    return with_fp16_ &&
+           !(phi_kernels_.find(half_dtype) == phi_kernels_.end()) &&
+           phi_kernels_[half_dtype]->IsValid();
+  }
+
  private:
   std::string op_meta_data_;
   framework::proto::OpDesc proto_op_desc_;
   framework::OpDesc op_desc_;
 
  private:
-  const phi::Kernel* phi_kernel_{nullptr};
+  std::unordered_map<nvinfer1::DataType, std::unique_ptr<phi::Kernel>>
+      phi_kernels_;
+  std::unordered_map<nvinfer1::DataType, std::unique_ptr<phi::KernelContext>>
+      phi_kernel_contexts_;
 
-  phi::KernelContext* phi_kernel_context_{nullptr};
   std::vector<phi::DenseTensor>* dense_tensor_inputs_{nullptr};
   std::vector<phi::DenseTensor>* dense_tensor_outputs_{nullptr};
```
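Since `getSerializationSize()` now counts `with_fp16_`, the write order in `serialize()` and the read order in the deserialization constructor must stay in lockstep, or every field after the mismatch is corrupted. A small sketch of that invariant, with plain stand-ins for Paddle's `SerializeValue`/`DeserializeValue` helpers:

```cpp
#include <cassert>
#include <cstring>

// Stand-ins for SerializeValue/DeserializeValue: write and read
// trivially-copyable values at a moving cursor.
template <typename T>
void Write(char** cursor, const T& value) {
  std::memcpy(*cursor, &value, sizeof(T));
  *cursor += sizeof(T);
}
template <typename T>
void Read(const char** cursor, T* value) {
  std::memcpy(value, *cursor, sizeof(T));
  *cursor += sizeof(T);
}

int main() {
  char buffer[64] = {};
  char* w = buffer;
  Write(&w, 5);     // input dtype id (illustrative value)
  Write(&w, 5);     // output dtype id
  Write(&w, true);  // with_fp16_ -- the new field, written after the dtypes

  const char* r = buffer;
  int in = 0, out = 0;
  bool with_fp16 = false;
  Read(&r, &in);    // must read back in exactly the write order
  Read(&r, &out);
  Read(&r, &with_fp16);
  assert(in == 5 && out == 5 && with_fp16);
  return 0;
}
```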