Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2bdad6cd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2bdad6cd
编写于
12月 01, 2022
作者:
Z
Zhang Jun
提交者:
GitHub
12月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[inference][trt] Fp16 support for Generic plugin (#48253)
* Support FP16 in generic TensorRT plugin. * Support FP16 for Pad3D.
上级
9ffc760f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
127 addition
and
67 deletion
+127
-67
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+1
-12
paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc
...nce/tensorrt/convert/generic_and_custom_plugin_creater.cc
+4
-1
paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc
...nference/tensorrt/convert/multihead_matmul_roformer_op.cc
+1
-1
paddle/fluid/inference/tensorrt/helper.h
paddle/fluid/inference/tensorrt/helper.h
+10
-0
paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu
paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu
+94
-47
paddle/fluid/inference/tensorrt/plugin/generic_plugin.h
paddle/fluid/inference/tensorrt/plugin/generic_plugin.h
+17
-6
未找到文件。
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
2bdad6cd
...
...
@@ -30,9 +30,7 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
...
...
@@ -42,15 +40,6 @@ namespace inference {
namespace
analysis
{
namespace
{
bool
IsFloat
(
framework
::
proto
::
VarType
::
Type
t
)
{
if
(
t
==
framework
::
proto
::
VarType
::
FP16
||
t
==
framework
::
proto
::
VarType
::
FP32
||
t
==
framework
::
proto
::
VarType
::
FP64
||
t
==
framework
::
proto
::
VarType
::
BF16
)
return
true
;
return
false
;
}
// if in mixed model precision, we should make all tensorrt_engine's output
// floats dtype to float32 dtype.
void
OutputProcess
(
framework
::
ir
::
Graph
*
graph
,
...
...
@@ -85,7 +74,7 @@ void OutputProcess(framework::ir::Graph *graph,
for
(
auto
*
var_node
:
op_node
->
outputs
)
{
if
(
!
trt_outputs
.
count
(
var_node
))
continue
;
if
(
!
var_node
->
Var
()
->
Persistable
()
&&
IsFloat
(
var_node
->
Var
()
->
GetDataType
())
&&
tensorrt
::
IsFloatVar
(
var_node
->
Var
()
->
GetDataType
())
&&
var_node
->
Var
()
->
GetDataType
()
!=
framework
::
proto
::
VarType
::
FP32
)
{
for
(
auto
*
next_op
:
var_node
->
outputs
)
{
// if next_op support mixed_precision, we need to add cast op.
...
...
paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc
浏览文件 @
2bdad6cd
...
...
@@ -182,6 +182,8 @@ class GenericPluginCreater : public OpConverter {
phi
::
DefaultKernelSignatureMap
::
Instance
().
Get
(
op_desc
.
Type
());
}
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
plugin
::
GenericPlugin
::
InputOutPutVarInfo
in_out_info
;
for
(
auto
&
param_name
:
phi_kernel_signature
.
input_names
)
{
...
...
@@ -218,7 +220,8 @@ class GenericPluginCreater : public OpConverter {
in_out_info
.
outputs_data_type
.
push_back
(
var
->
GetDataType
());
}
}
plugin
::
GenericPlugin
*
plugin
=
new
plugin
::
GenericPlugin
(
op
,
in_out_info
);
plugin
::
GenericPlugin
*
plugin
=
new
plugin
::
GenericPlugin
(
op
,
in_out_info
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
inputs
.
data
(),
inputs
.
size
(),
plugin
);
RreplenishLayerAndOutput
(
layer
,
op_desc
.
Type
(),
output_names
,
test_mode
);
...
...
paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc
浏览文件 @
2bdad6cd
...
...
@@ -60,7 +60,7 @@ class MultiheadMatMulRoformerOpConverter : public OpConverter {
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
// (hidden_in, 3, hidden_out)
auto
weight_dims
=
weight_t
->
dims
();
auto
&
weight_dims
=
weight_t
->
dims
();
int
hidden_in
=
weight_dims
[
0
];
// channels_in
int
three
=
weight_dims
[
1
];
// channels_out
...
...
paddle/fluid/inference/tensorrt/helper.h
浏览文件 @
2bdad6cd
...
...
@@ -22,6 +22,7 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/data_type.h"
...
...
@@ -213,6 +214,15 @@ static inline nvinfer1::DataType PhiType2NvType(phi::DataType type) {
}
return
nv_type
;
}
static
bool
IsFloatVar
(
framework
::
proto
::
VarType
::
Type
t
)
{
if
(
t
==
framework
::
proto
::
VarType
::
FP16
||
t
==
framework
::
proto
::
VarType
::
FP32
||
t
==
framework
::
proto
::
VarType
::
FP64
||
t
==
framework
::
proto
::
VarType
::
BF16
)
return
true
;
return
false
;
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu
浏览文件 @
2bdad6cd
...
...
@@ -30,8 +30,11 @@ namespace plugin {
void
BuildPhiKernelContextAttr
(
const
framework
::
OpDesc
&
op_desc
,
phi
::
KernelContext
*
kernel_context
,
const
phi
::
KernelSignature
&
signature
,
const
phi
::
Kernel
&
phi_kernel
)
{
const
phi
::
KernelArgsDef
&
args_def
=
phi_kernel
.
args_def
();
const
phi
::
Kernel
*
phi_kernel
)
{
if
(
!
phi_kernel
->
IsValid
())
{
return
;
}
const
phi
::
KernelArgsDef
&
args_def
=
phi_kernel
->
args_def
();
const
auto
&
attr_names
=
signature
.
attr_names
;
const
auto
&
attr_defs
=
args_def
.
attribute_defs
();
...
...
@@ -221,28 +224,34 @@ void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
GenericPlugin
::
GenericPlugin
(
const
paddle
::
framework
::
proto
::
OpDesc
&
proto_op_desc
,
const
InputOutPutVarInfo
&
in_out_info
)
{
const
InputOutPutVarInfo
&
in_out_info
,
bool
with_fp16
)
{
proto_op_desc_
=
proto_op_desc
;
op_desc_
=
std
::
move
(
framework
::
OpDesc
(
proto_op_desc_
,
nullptr
));
proto_op_desc_
.
SerializeToString
(
&
op_meta_data_
);
inputs_data_type_
=
in_out_info
.
inputs_data_type
;
outputs_data_type_
=
in_out_info
.
outputs_data_type
;
with_fp16_
=
with_fp16
;
}
GenericPlugin
::
GenericPlugin
(
const
paddle
::
framework
::
proto
::
OpDesc
&
proto_op_desc
,
const
std
::
vector
<
int
>&
inputs_data_type
,
const
std
::
vector
<
int
>&
outputs_data_type
)
{
const
std
::
vector
<
int
>&
outputs_data_type
,
bool
with_fp16
)
{
proto_op_desc_
=
proto_op_desc
;
op_desc_
=
std
::
move
(
framework
::
OpDesc
(
proto_op_desc_
,
nullptr
));
proto_op_desc_
.
SerializeToString
(
&
op_meta_data_
);
inputs_data_type_
=
inputs_data_type
;
outputs_data_type_
=
outputs_data_type
;
with_fp16_
=
with_fp16
;
}
GenericPlugin
::
GenericPlugin
(
void
const
*
serial_data
,
size_t
serial_length
)
{
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
inputs_data_type_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
outputs_data_type_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
std
::
string
op_meta_data
((
char
*
)(
serial_data
),
serial_length
);
// NOLINT
op_meta_data_
=
std
::
move
(
op_meta_data
);
proto_op_desc_
.
ParseFromString
(
op_meta_data_
);
...
...
@@ -266,8 +275,8 @@ int GenericPlugin::getNbInputs() const TRT_NOEXCEPT {
}
nvinfer1
::
IPluginV2DynamicExt
*
GenericPlugin
::
clone
()
const
TRT_NOEXCEPT
{
nvinfer1
::
IPluginV2DynamicExt
*
plugin
=
new
GenericPlugin
(
proto_op_desc_
,
inputs_data_type_
,
outputs_data_type
_
);
nvinfer1
::
IPluginV2DynamicExt
*
plugin
=
new
GenericPlugin
(
proto_op_desc_
,
inputs_data_type_
,
outputs_data_type_
,
with_fp16
_
);
plugin
->
initialize
();
return
plugin
;
}
...
...
@@ -277,6 +286,8 @@ void GenericPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue
(
&
buffer
,
inputs_data_type_
);
// outputs_data_type_
SerializeValue
(
&
buffer
,
outputs_data_type_
);
// use fp16
SerializeValue
(
&
buffer
,
with_fp16_
);
// serialize op_meta_data_
std
::
memcpy
(
buffer
,
op_meta_data_
.
c_str
(),
op_meta_data_
.
size
());
reinterpret_cast
<
char
*&>
(
buffer
)
+=
op_meta_data_
.
size
();
...
...
@@ -310,6 +321,12 @@ bool GenericPlugin::supportsFormatCombination(
if
(
pos
==
3
)
return
(
in_out
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in_out
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
else
if
(
op_desc_
.
Type
()
==
"pad3d"
)
{
return
(
in_out
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
||
(
isFp16Supported
()
&&
in_out
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
))
&&
(
in_out
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
)
&&
(
in_out
[
0
].
type
==
in_out
[
pos
].
type
);
}
else
{
return
(
in_out
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in_out
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
...
...
@@ -337,34 +354,43 @@ int GenericPlugin::initialize() TRT_NOEXCEPT {
phi
::
DefaultKernelSignatureMap
::
Instance
().
Get
(
op_type
);
}
phi
::
KernelKey
phi_kernel_key
(
phi
::
Backend
::
GPU
,
phi
::
DataLayout
::
ANY
,
phi
::
DataType
::
FLOAT32
);
PADDLE_ENFORCE_EQ
(
phi
::
KernelFactory
::
Instance
().
HasCompatiblePhiKernel
(
op_type
),
true
,
platform
::
errors
::
Fatal
(
"%s has no compatible phi kernel!"
,
op_type
.
c_str
()));
const
phi
::
Kernel
&
phi_kernel
=
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
phi_kernel_signature
.
name
,
phi_kernel_key
);
phi_kernel_
=
&
phi_kernel
;
PADDLE_ENFORCE_EQ
(
phi_kernel_
->
IsValid
(),
true
,
platform
::
errors
::
Fatal
(
"%s phi kernel is invalid!."
,
phi_kernel_signature
.
name
));
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
::
Instance
();
platform
::
CUDAPlace
place
(
platform
::
GetCurrentDeviceId
());
auto
*
dev_ctx
=
static_cast
<
phi
::
GPUContext
*>
(
pool
.
Get
(
place
));
if
(
!
phi_kernel_context_
)
{
phi_kernel_context_
=
new
phi
::
KernelContext
(
dev_ctx
);
BuildPhiKernelContextAttr
(
op_desc_
,
phi_kernel_context_
,
phi_kernel_signature
,
phi_kernel
);
std
::
vector
<
phi
::
DataType
>
precision_types
{
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
};
for
(
auto
&
precision_type
:
precision_types
)
{
phi
::
KernelKey
phi_kernel_key
(
phi
::
Backend
::
GPU
,
phi
::
DataLayout
::
ANY
,
precision_type
);
auto
nv_dtype
=
PhiType2NvType
(
precision_type
);
phi_kernels_
[
nv_dtype
].
reset
(
new
phi
::
Kernel
(
phi
::
KernelFactory
::
Instance
().
SelectKernel
(
phi_kernel_signature
.
name
,
phi_kernel_key
)));
if
(
phi_kernel_contexts_
.
find
(
nv_dtype
)
==
phi_kernel_contexts_
.
end
()
||
!
phi_kernel_contexts_
[
nv_dtype
])
{
phi_kernel_contexts_
[
nv_dtype
].
reset
(
new
phi
::
KernelContext
(
dev_ctx
));
BuildPhiKernelContextAttr
(
op_desc_
,
phi_kernel_contexts_
[
nv_dtype
].
get
(),
phi_kernel_signature
,
phi_kernels_
[
nv_dtype
].
get
());
}
}
PADDLE_ENFORCE_EQ
(
phi_kernels_
[
nvinfer1
::
DataType
::
kFLOAT
]
->
IsValid
()
||
phi_kernels_
[
nvinfer1
::
DataType
::
kHALF
]
->
IsValid
(),
true
,
platform
::
errors
::
Fatal
(
"%s phi kernel is invalid!."
,
phi_kernel_signature
.
name
));
if
(
!
dense_tensor_inputs_
)
dense_tensor_inputs_
=
new
std
::
vector
<
phi
::
DenseTensor
>
(
getNbInputs
());
if
(
!
dense_tensor_outputs_
)
...
...
@@ -396,15 +422,14 @@ void GenericPlugin::configurePlugin(
int
nb_inputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nb_outputs
)
TRT_NOEXCEPT
{
CHECK
(
phi_kernel
_context_
);
CHECK
(
phi_kernel_
);
CHECK
(
phi_kernel
s_
[
nvinfer1
::
DataType
::
kFLOAT
]
->
IsValid
()
||
phi_kernels_
[
nvinfer1
::
DataType
::
kHALF
]
->
IsValid
()
);
CHECK
(
nb_inputs
==
getNbInputs
());
CHECK
(
nb_outputs
==
getNbOutputs
());
}
// Shutdown the layer. This is called when the engine is destroyed
void
GenericPlugin
::
terminate
()
TRT_NOEXCEPT
{
delete
phi_kernel_context_
;
delete
dense_tensor_inputs_
;
delete
dense_tensor_outputs_
;
}
...
...
@@ -418,27 +443,42 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
platform
::
CUDAPlace
place
(
platform
::
GetCurrentDeviceId
());
// [TODO]now generic plugin do not support FP16 and INT8 precision
auto
protoType2PhiType
=
[](
int
proto_type
)
->
std
::
pair
<
phi
::
DataType
,
int
>
{
auto
protoType2PhiType
=
[
&
](
int
proto_type
,
nvinfer1
::
DataType
nv_dtype
)
->
std
::
pair
<
phi
::
DataType
,
int
>
{
if
(
proto_type
==
static_cast
<
int
>
(
framework
::
proto
::
VarType_Type
::
VarType_Type_FP32
))
return
{
phi
::
DataType
::
FLOAT32
,
sizeof
(
float
)};
else
if
(
proto_type
==
static_cast
<
int
>
(
framework
::
proto
::
VarType_Type
::
VarType_Type_INT64
)
||
proto_type
==
static_cast
<
int
>
(
framework
::
proto
::
VarType_Type
::
VarType_Type_INT32
))
static_cast
<
int
>
(
framework
::
proto
::
VarType_Type
::
VarType_Type_FP16
))
{
return
{
phi
::
DataType
::
FLOAT16
,
sizeof
(
half
)};
}
else
if
(
proto_type
==
static_cast
<
int
>
(
framework
::
proto
::
VarType_Type
::
VarType_Type_FP32
))
{
if
(
isFp16Supported
()
&&
nv_dtype
==
nvinfer1
::
DataType
::
kHALF
)
{
return
{
phi
::
DataType
::
FLOAT16
,
sizeof
(
half
)};
}
else
{
return
{
phi
::
DataType
::
FLOAT32
,
sizeof
(
float
)};
}
}
else
if
(
proto_type
==
static_cast
<
int
>
(
framework
::
proto
::
VarType_Type
::
VarType_Type_INT64
)
||
proto_type
==
static_cast
<
int
>
(
framework
::
proto
::
VarType_Type
::
VarType_Type_INT32
))
{
return
{
phi
::
DataType
::
INT32
,
sizeof
(
int32_t
)};
else
if
(
proto_type
==
static_cast
<
int
>
(
framework
::
proto
::
VarType_Type
::
VarType_Type_BOOL
))
}
else
if
(
proto_type
==
static_cast
<
int
>
(
framework
::
proto
::
VarType_Type
::
VarType_Type_BOOL
))
{
return
{
phi
::
DataType
::
BOOL
,
sizeof
(
bool
)};
else
}
else
{
CHECK
(
false
)
<<
"precision is not supported"
;
}
};
// input
phi_kernel_context_
->
ClearInputOutput
();
auto
data_type
=
input_desc
[
0
].
type
;
CHECK
((
data_type
==
nvinfer1
::
DataType
::
kFLOAT
)
||
(
data_type
==
nvinfer1
::
DataType
::
kHALF
));
phi_kernel_contexts_
[
data_type
]
->
ClearInputOutput
();
for
(
int
i
=
0
;
i
<
getNbInputs
();
i
++
)
{
auto
const
&
input_dims
=
input_desc
[
i
].
dims
;
...
...
@@ -450,7 +490,9 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
int
input_numel
=
1
;
for
(
int
k
=
0
;
k
<
input_shape
.
size
();
k
++
)
input_numel
*=
input_shape
[
k
];
auto
data_type_and_size
=
protoType2PhiType
(
inputs_data_type_
[
i
]);
auto
data_type_and_size
=
protoType2PhiType
(
inputs_data_type_
[
i
],
data_type
);
phi
::
DenseTensorMeta
input_meta
(
data_type_and_size
.
first
,
phi
::
make_ddim
(
input_shape
));
std
::
shared_ptr
<
phi
::
Allocation
>
input_alloc
(
...
...
@@ -459,9 +501,9 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
place
));
(
*
dense_tensor_inputs_
)[
i
]
=
std
::
move
(
phi
::
DenseTensor
(
input_alloc
,
input_meta
));
phi_kernel_context_
->
EmplaceBackInput
(
&
((
*
dense_tensor_inputs_
)[
i
]));
phi_kernel_contexts_
[
data_type
]
->
EmplaceBackInput
(
&
((
*
dense_tensor_inputs_
)[
i
]));
}
// output
for
(
int
i
=
0
;
i
<
getNbOutputs
();
i
++
)
{
auto
const
&
output_dims
=
output_desc
[
i
].
dims
;
...
...
@@ -474,23 +516,28 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
for
(
int
k
=
0
;
k
<
output_shape
.
size
();
k
++
)
output_numel
*=
output_shape
[
k
];
auto
data_type_and_size
=
protoType2PhiType
(
inputs_data_type_
[
i
]);
auto
data_type_and_size
=
protoType2PhiType
(
inputs_data_type_
[
i
],
data_type
);
phi
::
DenseTensorMeta
output_meta
(
data_type_and_size
.
first
,
phi
::
make_ddim
(
output_shape
));
std
::
shared_ptr
<
phi
::
Allocation
>
output_alloc
(
new
phi
::
Allocation
(
reinterpret_cast
<
void
*>
(
outputs
[
i
]),
output_numel
*
data_type_and_size
.
second
,
place
));
phi
::
DenseTensor
output_densetonsor
(
output_alloc
,
output_meta
);
(
*
dense_tensor_outputs_
)[
i
]
=
std
::
move
(
phi
::
DenseTensor
(
output_alloc
,
output_meta
));
phi_kernel_context_
->
EmplaceBackOutput
(
&
((
*
dense_tensor_outputs_
)[
i
]));
phi_kernel_contexts_
[
data_type
]
->
EmplaceBackOutput
(
&
((
*
dense_tensor_outputs_
)[
i
]));
}
CHECK_EQ
(
phi_kernel_context
_
->
InputsSize
(),
getNbInputs
());
CHECK_EQ
(
phi_kernel_context
_
->
OutputsSize
(),
getNbOutputs
());
CHECK_EQ
(
phi_kernel_context
s_
[
data_type
]
->
InputsSize
(),
getNbInputs
());
CHECK_EQ
(
phi_kernel_context
s_
[
data_type
]
->
OutputsSize
(),
getNbOutputs
());
(
*
phi_kernel
_
)(
phi_kernel_context_
);
(
*
phi_kernel
s_
[
data_type
])(
phi_kernel_contexts_
[
data_type
].
get
()
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
...
...
paddle/fluid/inference/tensorrt/plugin/generic_plugin.h
浏览文件 @
2bdad6cd
...
...
@@ -44,7 +44,7 @@ namespace plugin {
void
BuildPhiKernelContextAttr
(
const
framework
::
OpDesc
&
op_desc
,
phi
::
KernelContext
*
kernel_context
,
const
phi
::
KernelSignature
&
signature
,
const
phi
::
Kernel
&
phi_kernel
);
const
phi
::
Kernel
*
phi_kernel
);
class
GenericPlugin
:
public
DynamicPluginTensorRT
{
public:
...
...
@@ -57,11 +57,13 @@ class GenericPlugin : public DynamicPluginTensorRT {
GenericPlugin
()
{}
GenericPlugin
(
const
paddle
::
framework
::
proto
::
OpDesc
&
proto_op_desc
,
const
InputOutPutVarInfo
&
in_out_info
);
const
InputOutPutVarInfo
&
in_out_info
,
bool
with_fp16_
=
false
);
GenericPlugin
(
const
paddle
::
framework
::
proto
::
OpDesc
&
proto_op_desc
,
const
std
::
vector
<
int
>&
inputs_data_type
,
const
std
::
vector
<
int
>&
outputs_data_type
);
const
std
::
vector
<
int
>&
outputs_data_type
,
bool
with_fp16_
=
false
);
// It was used for tensorrt deserialization.
// It should not be called by users.
...
...
@@ -86,7 +88,7 @@ class GenericPlugin : public DynamicPluginTensorRT {
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
{
return
op_meta_data_
.
size
()
+
SerializedSize
(
inputs_data_type_
)
+
SerializedSize
(
outputs_data_type_
);
SerializedSize
(
outputs_data_type_
)
+
SerializedSize
(
with_fp16_
)
;
}
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
;
...
...
@@ -122,15 +124,24 @@ class GenericPlugin : public DynamicPluginTensorRT {
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
;
bool
isFp16Supported
()
{
auto
half_dtype
=
nvinfer1
::
DataType
::
kHALF
;
return
with_fp16_
&&
!
(
phi_kernels_
.
find
(
half_dtype
)
==
phi_kernels_
.
end
())
&&
phi_kernels_
[
half_dtype
]
->
IsValid
();
}
private:
std
::
string
op_meta_data_
;
framework
::
proto
::
OpDesc
proto_op_desc_
;
framework
::
OpDesc
op_desc_
;
private:
const
phi
::
Kernel
*
phi_kernel_
{
nullptr
};
std
::
unordered_map
<
nvinfer1
::
DataType
,
std
::
unique_ptr
<
phi
::
Kernel
>>
phi_kernels_
;
std
::
unordered_map
<
nvinfer1
::
DataType
,
std
::
unique_ptr
<
phi
::
KernelContext
>>
phi_kernel_contexts_
;
phi
::
KernelContext
*
phi_kernel_context_
{
nullptr
};
std
::
vector
<
phi
::
DenseTensor
>*
dense_tensor_inputs_
{
nullptr
};
std
::
vector
<
phi
::
DenseTensor
>*
dense_tensor_outputs_
{
nullptr
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录