Commit 0b962680

Authored Nov 13, 2018 by nhzlx

fix comments

test=develop

Parent: e5bf8616
Showing 8 changed files with 64 additions and 60 deletions (+64 −60)
paddle/fluid/inference/tensorrt/convert/split_op.cc           +3  −1
paddle/fluid/inference/tensorrt/engine.cc                     +1  −1
paddle/fluid/inference/tensorrt/engine.h                      +1  −1
paddle/fluid/inference/tensorrt/plugin/serialize.h            +26 −26
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu     +0  −3
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h      +10 −13
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc          +10 −10
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h           +13 −5
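The changes are consistent across all eight files: the free-function serialization helpers in serialize.h move from snake_case (serialize_value, deserialize_value, serialized_size) to CamelCase (SerializeValue, DeserializeValue, SerializedSize), matching the naming style used elsewhere in the codebase; bare ::memcpy and ::strcpy calls become std::-qualified; TensorRTEngine::addPlugin is renamed AddPlugin; SplitPlugin drops its thrust::device_vector member together with the <thrust/device_vector.h> include and the unused setAxis/setOutputLengths setters; and PluginTensorRT gains comments spelling out which methods a plugin subclass must override.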
paddle/fluid/inference/tensorrt/convert/split_op.cc

@@ -35,6 +35,7 @@ class SplitOpConverter : public OpConverter {
     int input_num = op_desc.Input("X").size();
     size_t output_num = op_desc.Output("Out").size();
+    // Get Attrs
     PADDLE_ENFORCE(input_num == 1);
     int axis = boost::get<int>(op_desc.GetAttr("axis"));
     std::vector<int> output_lengths =
@@ -48,9 +49,10 @@ class SplitOpConverter : public OpConverter {
     PADDLE_ENFORCE(output_lengths.size() == output_num);
+    //
     SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
     nvinfer1::IPluginLayer* layer =
-        engine_->addPlugin(&input, input_num, plugin);
+        engine_->AddPlugin(&input, input_num, plugin);
     std::string layer_name = "split (Output: ";
     for (size_t i = 0; i < output_num; i++) {
paddle/fluid/inference/tensorrt/engine.cc

@@ -254,7 +254,7 @@ void TensorRTEngine::freshDeviceId() {
   cudaSetDevice(device_);
 }

-nvinfer1::IPluginLayer *TensorRTEngine::addPlugin(
+nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
     nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
   owned_plugin_.emplace_back(plugin);
   return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
 }
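The rename is mechanical, but the ownership contract it touches is worth noting: AddPlugin stashes the raw pointer in owned_plugin_, so the engine appears to keep the plugin alive for the network's lifetime and the caller should not delete it. The converter-side pattern, as in split_op.cc above:

    SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
    nvinfer1::IPluginLayer* layer =
        engine_->AddPlugin(&input, input_num, plugin);  // engine takes ownership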
paddle/fluid/inference/tensorrt/engine.h

@@ -126,7 +126,7 @@ class TensorRTEngine : public EngineBase {
   void SetRuntimeBatch(size_t batch_size);
   int GetRuntimeBatch();
   int GetDevice() { return device_; }
-  nvinfer1::IPluginLayer* addPlugin(nvinfer1::ITensor* const* inputs,
+  nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
                                     int nbInputs, PluginTensorRT*);

   // A pointer to CPU memory is needed of the TRT weight.
paddle/fluid/inference/tensorrt/plugin/serialize.h

@@ -20,10 +20,10 @@
 #include <vector>

 template <typename T>
-inline void serialize_value(void** buffer, T const& value);
+inline void SerializeValue(void** buffer, T const& value);

 template <typename T>
-inline void deserialize_value(void const** buffer, size_t* buffer_size,
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
                               T* value);

 namespace {
@@ -35,14 +35,14 @@ template <typename T>
 struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
                                              std::is_enum<T>::value ||
                                              std::is_pod<T>::value>::type> {
-  static size_t serialized_size(T const& value) { return sizeof(T); }
-  static void serialize(void** buffer, T const& value) {
-    ::memcpy(*buffer, &value, sizeof(T));
+  static size_t SerializedSize(T const& value) { return sizeof(T); }
+  static void Serialize(void** buffer, T const& value) {
+    std::memcpy(*buffer, &value, sizeof(T));
     reinterpret_cast<char*&>(*buffer) += sizeof(T);
   }
-  static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
+  static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
     assert(*buffer_size >= sizeof(T));
-    ::memcpy(value, *buffer, sizeof(T));
+    std::memcpy(value, *buffer, sizeof(T));
     reinterpret_cast<char const*&>(*buffer) += sizeof(T);
     *buffer_size -= sizeof(T);
   }
@@ -50,12 +50,12 @@ struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
 template <>
 struct Serializer<const char*> {
-  static size_t serialized_size(const char* value) { return strlen(value) + 1; }
-  static void serialize(void** buffer, const char* value) {
-    ::strcpy(static_cast<char*>(*buffer), value);
+  static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
+  static void Serialize(void** buffer, const char* value) {
+    std::strcpy(static_cast<char*>(*buffer), value);
     reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
   }
-  static void deserialize(void const** buffer, size_t* buffer_size,
+  static void Deserialize(void const** buffer, size_t* buffer_size,
                           const char** value) {
     *value = static_cast<char const*>(*buffer);
     size_t data_size = strnlen(*value, *buffer_size) + 1;
@@ -70,23 +70,23 @@ struct Serializer<std::vector<T>,
                   typename std::enable_if<std::is_arithmetic<T>::value ||
                                           std::is_enum<T>::value ||
                                           std::is_pod<T>::value>::type> {
-  static size_t serialized_size(std::vector<T> const& value) {
+  static size_t SerializedSize(std::vector<T> const& value) {
     return sizeof(value.size()) + value.size() * sizeof(T);
   }
-  static void serialize(void** buffer, std::vector<T> const& value) {
-    serialize_value(buffer, value.size());
+  static void Serialize(void** buffer, std::vector<T> const& value) {
+    SerializeValue(buffer, value.size());
     size_t nbyte = value.size() * sizeof(T);
-    ::memcpy(*buffer, value.data(), nbyte);
+    std::memcpy(*buffer, value.data(), nbyte);
     reinterpret_cast<char*&>(*buffer) += nbyte;
   }
-  static void deserialize(void const** buffer, size_t* buffer_size,
+  static void Deserialize(void const** buffer, size_t* buffer_size,
                           std::vector<T>* value) {
     size_t size;
-    deserialize_value(buffer, buffer_size, &size);
+    DeserializeValue(buffer, buffer_size, &size);
     value->resize(size);
     size_t nbyte = value->size() * sizeof(T);
     assert(*buffer_size >= nbyte);
-    ::memcpy(value->data(), *buffer, nbyte);
+    std::memcpy(value->data(), *buffer, nbyte);
     reinterpret_cast<char const*&>(*buffer) += nbyte;
     *buffer_size -= nbyte;
   }
@@ -95,17 +95,17 @@ struct Serializer<std::vector<T>,
 }  // namespace

 template <typename T>
-inline size_t serialized_size(T const& value) {
-  return Serializer<T>::serialized_size(value);
+inline size_t SerializedSize(T const& value) {
+  return Serializer<T>::SerializedSize(value);
 }

 template <typename T>
-inline void serialize_value(void** buffer, T const& value) {
-  return Serializer<T>::serialize(buffer, value);
+inline void SerializeValue(void** buffer, T const& value) {
+  return Serializer<T>::Serialize(buffer, value);
 }

 template <typename T>
-inline void deserialize_value(void const** buffer, size_t* buffer_size,
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
                               T* value) {
-  return Serializer<T>::deserialize(buffer, buffer_size, value);
+  return Serializer<T>::Deserialize(buffer, buffer_size, value);
 }
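Taken together, the helpers give a length-checked round trip: size the buffer with SerializedSize, write with SerializeValue (which advances the cursor past each field), and read back in the same order with DeserializeValue (which also decrements the remaining-byte count as a bounds check). A minimal caller sketch, hypothetical and assuming serialize.h is included:

    #include <cassert>
    #include <vector>

    void RoundTripExample() {
      int axis = 1;
      std::vector<int> lengths = {2, 3};

      // Size one contiguous buffer for both fields.
      size_t total = SerializedSize(axis) + SerializedSize(lengths);
      std::vector<char> storage(total);

      // Write: SerializeValue advances *buffer past each field.
      void* out = storage.data();
      SerializeValue(&out, axis);
      SerializeValue(&out, lengths);

      // Read back in the same order; `remaining` shrinks as fields are consumed.
      void const* in = storage.data();
      size_t remaining = total;
      int axis2;
      std::vector<int> lengths2;
      DeserializeValue(&in, &remaining, &axis2);
      DeserializeValue(&in, &remaining, &lengths2);
      assert(axis2 == axis && lengths2 == lengths);
    }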
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu

@@ -37,7 +37,6 @@ int SplitPlugin::initialize() {
     segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
   }
   segment_offsets_ = segment_offsets;
-  d_segment_offsets_ = segment_offsets;
   nvinfer1::Dims dims = this->getInputDims(0);
   nx_ = 1;
   for (int i = dims.nbDims - 1; i > axis_; --i) {
@@ -55,8 +54,6 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                          void** outputs, void* workspace, cudaStream_t stream) {
   auto const& input_dims = this->getInputDims(0);
   int input_size = 0;
-  int const* d_segment_offsets_ptr =
-      thrust::raw_pointer_cast(&d_segment_offsets_[0]);
   float const* idata = reinterpret_cast<float const*>(inputs[0]);
   float** odatas = reinterpret_cast<float**>(outputs);
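The deletions mirror the header change below: the device-side copy of the segment offsets and the raw pointer derived from it are gone, leaving the host-side segment_offsets_ as the single copy. How the offsets reach the kernel after this change falls outside the visible hunks.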
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h

@@ -14,7 +14,6 @@
 #pragma once

-#include <thrust/device_vector.h>
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

 namespace paddle {
@@ -25,19 +24,21 @@ class SplitPlugin : public PluginTensorRT {
   int axis_;
   std::vector<int> output_length_;
   int nx_, ny_, nz_;
-  thrust::device_vector<int> d_segment_offsets_;
   std::vector<int> segment_offsets_;

  protected:
   virtual size_t getSerializationSize() override {
-    return serialized_size(axis_) + serialized_size(output_length_) +
+    return SerializedSize(axis_) + SerializedSize(output_length_) +
            getBaseSerializationSize();
   }

+  // TRT will call this func when we need to serialize the configuration of
+  // tensorrt.
+  // It should not be called by users.
   virtual void serialize(void* buffer) override {
     serializeBase(buffer);
-    serialize_value(&buffer, axis_);
-    serialize_value(&buffer, output_length_);
+    SerializeValue(&buffer, axis_);
+    SerializeValue(&buffer, output_length_);
   }

  public:
@@ -46,10 +47,12 @@ class SplitPlugin : public PluginTensorRT {
     assert(axis <= nvinfer1::Dims::MAX_DIMS);
   }

+  // It was used for tensorrt deserialization.
+  // It should not be called by users.
   SplitPlugin(void const* serialData, size_t serialLength) {
     deserializeBase(serialData, serialLength);
-    deserialize_value(&serialData, &serialLength, &axis_);
-    deserialize_value(&serialData, &serialLength, &output_length_);
+    DeserializeValue(&serialData, &serialLength, &axis_);
+    DeserializeValue(&serialData, &serialLength, &output_length_);
   }

   SplitPlugin* clone() const override {
@@ -64,12 +67,6 @@ class SplitPlugin : public PluginTensorRT {
   virtual int initialize() override;
   virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
                       void* workspace, cudaStream_t stream) override;
-
-  void setAxis(int axis) { axis_ = axis; }
-
-  void setOutputLengths(const std::vector<int>& output_lengths) {
-    output_length_ = output_lengths;
-  }
 };

 }  // tensorrt
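Dropping <thrust/device_vector.h> also means the header no longer pulls thrust into every translation unit that includes it. As the new comment notes, the (serialData, serialLength) constructor exists only so TensorRT can rebuild the plugin while deserializing an engine. A hedged sketch of that path; the factory class is hypothetical, though IPluginFactory::createPlugin is the stock hook in TensorRT releases of this era:

    class TrtPluginFactory : public nvinfer1::IPluginFactory {
     public:
      // Called by TRT for each plugin layer found in a serialized engine.
      nvinfer1::IPlugin* createPlugin(const char* layerName,
                                      const void* serialData,
                                      size_t serialLength) override {
        // The ctor re-reads exactly what SplitPlugin::serialize() wrote:
        // base fields first (deserializeBase), then axis_ and output_length_.
        return new SplitPlugin(serialData, serialLength);
      }
    };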
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc

@@ -19,23 +19,23 @@ namespace inference {
 namespace tensorrt {

 void PluginTensorRT::serializeBase(void*& buffer) {
-  serialize_value(&buffer, input_dims_);
-  serialize_value(&buffer, max_batch_size_);
-  serialize_value(&buffer, data_type_);
-  serialize_value(&buffer, data_format_);
+  SerializeValue(&buffer, input_dims_);
+  SerializeValue(&buffer, max_batch_size_);
+  SerializeValue(&buffer, data_type_);
+  SerializeValue(&buffer, data_format_);
 }

 void PluginTensorRT::deserializeBase(void const*& serialData,
                                      size_t& serialLength) {
-  deserialize_value(&serialData, &serialLength, &input_dims_);
-  deserialize_value(&serialData, &serialLength, &max_batch_size_);
-  deserialize_value(&serialData, &serialLength, &data_type_);
-  deserialize_value(&serialData, &serialLength, &data_format_);
+  DeserializeValue(&serialData, &serialLength, &input_dims_);
+  DeserializeValue(&serialData, &serialLength, &max_batch_size_);
+  DeserializeValue(&serialData, &serialLength, &data_type_);
+  DeserializeValue(&serialData, &serialLength, &data_format_);
 }

 size_t PluginTensorRT::getBaseSerializationSize() {
-  return (serialized_size(input_dims_) + serialized_size(max_batch_size_) +
-          serialized_size(data_type_) + serialized_size(data_format_));
+  return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
+          SerializedSize(data_type_) + SerializedSize(data_format_));
 }

 bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
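Note the invariant these renames preserve: serializeBase and deserializeBase must stay symmetric, writing and re-reading the fields in the fixed order input_dims_, max_batch_size_, data_type_, data_format_, while getBaseSerializationSize() accounts for exactly the same set.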
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h

@@ -41,11 +41,7 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
   size_t getWorkspaceSize(int) const override { return 0; }
   void terminate() override {}
   virtual ~PluginTensorRT() {}
-
-  // The following functions need to be overrided in the subclass.
-  virtual nvinfer1::IPluginExt* clone() const = 0;
-  virtual const char* getPluginType() const = 0;
-  int initialize() override { return 0; }
+  // Check format support. The default is FLOAT32 and NCHW.
   bool supportsFormat(nvinfer1::DataType type,
                       nvinfer1::PluginFormat format) const override;
   void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
@@ -53,12 +49,24 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
                            nvinfer1::DataType type,
                            nvinfer1::PluginFormat format,
                            int maxBatchSize) override;
+
+  // *NOTE* The following functions need to be overrided in the subclass.
+  virtual nvinfer1::IPluginExt* clone() const = 0;
+  virtual const char* getPluginType() const = 0;
+  // Initialize the layer for execution. This is called when the engine is
+  // created.
+  int initialize() override { return 0; }
+  // Serialize the layer config to buffer.
   virtual void serialize(void* buffer) = 0;
   virtual size_t getSerializationSize() = 0;
+  virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
+                      void* workspace, cudaStream_t stream) = 0;

  protected:
+  // Deserialize input_dims, max_batch_size, data_type, data_format
   void deserializeBase(void const*& serialData, size_t& serialLength);
   size_t getBaseSerializationSize();
+  // Serialize input_dims, max_batch_size, data_type, data_format
   void serializeBase(void*& buffer);

   std::vector<nvinfer1::Dims> input_dims_;
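For reference, this is the override surface the new comments call out. A minimal, hypothetical subclass sketch, assuming only the PluginTensorRT interface above plus the usual nvinfer1::IPlugin shape queries (getNbOutputs, getOutputDimensions); SplitPlugin in this commit fills in the same slots:

    class IdentityPlugin : public PluginTensorRT {
     public:
      IdentityPlugin() {}
      // Deserialization path: restore the base fields saved by serializeBase.
      IdentityPlugin(void const* serialData, size_t serialLength) {
        deserializeBase(serialData, serialLength);
      }

      IdentityPlugin* clone() const override { return new IdentityPlugin(); }
      const char* getPluginType() const override { return "identity_plugin"; }
      int getNbOutputs() const override { return 1; }
      nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                         int nbInputDims) override {
        return inputs[0];  // identity: output shape equals input shape
      }
      size_t getSerializationSize() override { return getBaseSerializationSize(); }
      void serialize(void* buffer) override { serializeBase(buffer); }

      int enqueue(int batchSize, const void* const* inputs, void** outputs,
                  void* workspace, cudaStream_t stream) override {
        // Copy input 0 to output 0; assumes float data, as SplitPlugin does.
        nvinfer1::Dims dims = this->getInputDims(0);
        size_t count = static_cast<size_t>(batchSize);
        for (int i = 0; i < dims.nbDims; ++i) count *= dims.d[i];
        cudaMemcpyAsync(outputs[0], inputs[0], count * sizeof(float),
                        cudaMemcpyDeviceToDevice, stream);
        return 0;
      }
    };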