Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0b962680
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0b962680
编写于
11月 13, 2018
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix comments
test=develop
上级
e5bf8616
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
64 addition
and
60 deletion
+64
-60
paddle/fluid/inference/tensorrt/convert/split_op.cc
paddle/fluid/inference/tensorrt/convert/split_op.cc
+3
-1
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+1
-1
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+1
-1
paddle/fluid/inference/tensorrt/plugin/serialize.h
paddle/fluid/inference/tensorrt/plugin/serialize.h
+26
-26
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
+0
-3
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
+10
-13
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
+10
-10
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
+13
-5
未找到文件。
paddle/fluid/inference/tensorrt/convert/split_op.cc
浏览文件 @
0b962680
...
...
@@ -35,6 +35,7 @@ class SplitOpConverter : public OpConverter {
int
input_num
=
op_desc
.
Input
(
"X"
).
size
();
size_t
output_num
=
op_desc
.
Output
(
"Out"
).
size
();
// Get Attrs
PADDLE_ENFORCE
(
input_num
==
1
);
int
axis
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"axis"
));
std
::
vector
<
int
>
output_lengths
=
...
...
@@ -48,9 +49,10 @@ class SplitOpConverter : public OpConverter {
PADDLE_ENFORCE
(
output_lengths
.
size
()
==
output_num
);
//
SplitPlugin
*
plugin
=
new
SplitPlugin
(
axis
,
output_lengths
);
nvinfer1
::
IPluginLayer
*
layer
=
engine_
->
a
ddPlugin
(
&
input
,
input_num
,
plugin
);
engine_
->
A
ddPlugin
(
&
input
,
input_num
,
plugin
);
std
::
string
layer_name
=
"split (Output: "
;
for
(
size_t
i
=
0
;
i
<
output_num
;
i
++
)
{
...
...
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
0b962680
...
...
@@ -254,7 +254,7 @@ void TensorRTEngine::freshDeviceId() {
cudaSetDevice
(
device_
);
}
nvinfer1
::
IPluginLayer
*
TensorRTEngine
::
a
ddPlugin
(
nvinfer1
::
IPluginLayer
*
TensorRTEngine
::
A
ddPlugin
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
nbInputs
,
PluginTensorRT
*
plugin
)
{
owned_plugin_
.
emplace_back
(
plugin
);
return
infer_network_
.
get
()
->
addPluginExt
(
inputs
,
nbInputs
,
*
plugin
);
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
0b962680
...
...
@@ -126,7 +126,7 @@ class TensorRTEngine : public EngineBase {
void
SetRuntimeBatch
(
size_t
batch_size
);
int
GetRuntimeBatch
();
int
GetDevice
()
{
return
device_
;
}
nvinfer1
::
IPluginLayer
*
a
ddPlugin
(
nvinfer1
::
ITensor
*
const
*
inputs
,
nvinfer1
::
IPluginLayer
*
A
ddPlugin
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
nbInputs
,
PluginTensorRT
*
);
// A pointer to CPU memory is needed of the TRT weight.
...
...
paddle/fluid/inference/tensorrt/plugin/serialize.h
浏览文件 @
0b962680
...
...
@@ -20,11 +20,11 @@
#include <vector>
template
<
typename
T
>
inline
void
serialize_v
alue
(
void
**
buffer
,
T
const
&
value
);
inline
void
SerializeV
alue
(
void
**
buffer
,
T
const
&
value
);
template
<
typename
T
>
inline
void
deserialize_v
alue
(
void
const
**
buffer
,
size_t
*
buffer_size
,
T
*
value
);
inline
void
DeserializeV
alue
(
void
const
**
buffer
,
size_t
*
buffer_size
,
T
*
value
);
namespace
{
...
...
@@ -35,14 +35,14 @@ template <typename T>
struct
Serializer
<
T
,
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
||
std
::
is_enum
<
T
>::
value
||
std
::
is_pod
<
T
>::
value
>::
type
>
{
static
size_t
serialized_s
ize
(
T
const
&
value
)
{
return
sizeof
(
T
);
}
static
void
s
erialize
(
void
**
buffer
,
T
const
&
value
)
{
::
memcpy
(
*
buffer
,
&
value
,
sizeof
(
T
));
static
size_t
SerializedS
ize
(
T
const
&
value
)
{
return
sizeof
(
T
);
}
static
void
S
erialize
(
void
**
buffer
,
T
const
&
value
)
{
std
::
memcpy
(
*
buffer
,
&
value
,
sizeof
(
T
));
reinterpret_cast
<
char
*&>
(
*
buffer
)
+=
sizeof
(
T
);
}
static
void
d
eserialize
(
void
const
**
buffer
,
size_t
*
buffer_size
,
T
*
value
)
{
static
void
D
eserialize
(
void
const
**
buffer
,
size_t
*
buffer_size
,
T
*
value
)
{
assert
(
*
buffer_size
>=
sizeof
(
T
));
::
memcpy
(
value
,
*
buffer
,
sizeof
(
T
));
std
::
memcpy
(
value
,
*
buffer
,
sizeof
(
T
));
reinterpret_cast
<
char
const
*&>
(
*
buffer
)
+=
sizeof
(
T
);
*
buffer_size
-=
sizeof
(
T
);
}
...
...
@@ -50,12 +50,12 @@ struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
template
<
>
struct
Serializer
<
const
char
*>
{
static
size_t
serialized_s
ize
(
const
char
*
value
)
{
return
strlen
(
value
)
+
1
;
}
static
void
s
erialize
(
void
**
buffer
,
const
char
*
value
)
{
::
strcpy
(
static_cast
<
char
*>
(
*
buffer
),
value
);
static
size_t
SerializedS
ize
(
const
char
*
value
)
{
return
strlen
(
value
)
+
1
;
}
static
void
S
erialize
(
void
**
buffer
,
const
char
*
value
)
{
std
::
strcpy
(
static_cast
<
char
*>
(
*
buffer
),
value
);
reinterpret_cast
<
char
*&>
(
*
buffer
)
+=
strlen
(
value
)
+
1
;
}
static
void
d
eserialize
(
void
const
**
buffer
,
size_t
*
buffer_size
,
static
void
D
eserialize
(
void
const
**
buffer
,
size_t
*
buffer_size
,
const
char
**
value
)
{
*
value
=
static_cast
<
char
const
*>
(
*
buffer
);
size_t
data_size
=
strnlen
(
*
value
,
*
buffer_size
)
+
1
;
...
...
@@ -70,23 +70,23 @@ struct Serializer<std::vector<T>,
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
||
std
::
is_enum
<
T
>::
value
||
std
::
is_pod
<
T
>::
value
>::
type
>
{
static
size_t
serialized_s
ize
(
std
::
vector
<
T
>
const
&
value
)
{
static
size_t
SerializedS
ize
(
std
::
vector
<
T
>
const
&
value
)
{
return
sizeof
(
value
.
size
())
+
value
.
size
()
*
sizeof
(
T
);
}
static
void
s
erialize
(
void
**
buffer
,
std
::
vector
<
T
>
const
&
value
)
{
serialize_v
alue
(
buffer
,
value
.
size
());
static
void
S
erialize
(
void
**
buffer
,
std
::
vector
<
T
>
const
&
value
)
{
SerializeV
alue
(
buffer
,
value
.
size
());
size_t
nbyte
=
value
.
size
()
*
sizeof
(
T
);
::
memcpy
(
*
buffer
,
value
.
data
(),
nbyte
);
std
::
memcpy
(
*
buffer
,
value
.
data
(),
nbyte
);
reinterpret_cast
<
char
*&>
(
*
buffer
)
+=
nbyte
;
}
static
void
d
eserialize
(
void
const
**
buffer
,
size_t
*
buffer_size
,
static
void
D
eserialize
(
void
const
**
buffer
,
size_t
*
buffer_size
,
std
::
vector
<
T
>*
value
)
{
size_t
size
;
deserialize_v
alue
(
buffer
,
buffer_size
,
&
size
);
DeserializeV
alue
(
buffer
,
buffer_size
,
&
size
);
value
->
resize
(
size
);
size_t
nbyte
=
value
->
size
()
*
sizeof
(
T
);
assert
(
*
buffer_size
>=
nbyte
);
::
memcpy
(
value
->
data
(),
*
buffer
,
nbyte
);
std
::
memcpy
(
value
->
data
(),
*
buffer
,
nbyte
);
reinterpret_cast
<
char
const
*&>
(
*
buffer
)
+=
nbyte
;
*
buffer_size
-=
nbyte
;
}
...
...
@@ -95,17 +95,17 @@ struct Serializer<std::vector<T>,
}
// namespace
template
<
typename
T
>
inline
size_t
serialized_s
ize
(
T
const
&
value
)
{
return
Serializer
<
T
>::
serialized_s
ize
(
value
);
inline
size_t
SerializedS
ize
(
T
const
&
value
)
{
return
Serializer
<
T
>::
SerializedS
ize
(
value
);
}
template
<
typename
T
>
inline
void
serialize_v
alue
(
void
**
buffer
,
T
const
&
value
)
{
return
Serializer
<
T
>::
s
erialize
(
buffer
,
value
);
inline
void
SerializeV
alue
(
void
**
buffer
,
T
const
&
value
)
{
return
Serializer
<
T
>::
S
erialize
(
buffer
,
value
);
}
template
<
typename
T
>
inline
void
deserialize_v
alue
(
void
const
**
buffer
,
size_t
*
buffer_size
,
T
*
value
)
{
return
Serializer
<
T
>::
d
eserialize
(
buffer
,
buffer_size
,
value
);
inline
void
DeserializeV
alue
(
void
const
**
buffer
,
size_t
*
buffer_size
,
T
*
value
)
{
return
Serializer
<
T
>::
D
eserialize
(
buffer
,
buffer_size
,
value
);
}
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
浏览文件 @
0b962680
...
...
@@ -37,7 +37,6 @@ int SplitPlugin::initialize() {
segment_offsets
.
push_back
(
segment_offsets
.
back
()
+
output_length_
[
i
]);
}
segment_offsets_
=
segment_offsets
;
d_segment_offsets_
=
segment_offsets
;
nvinfer1
::
Dims
dims
=
this
->
getInputDims
(
0
);
nx_
=
1
;
for
(
int
i
=
dims
.
nbDims
-
1
;
i
>
axis_
;
--
i
)
{
...
...
@@ -55,8 +54,6 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
const
&
input_dims
=
this
->
getInputDims
(
0
);
int
input_size
=
0
;
int
const
*
d_segment_offsets_ptr
=
thrust
::
raw_pointer_cast
(
&
d_segment_offsets_
[
0
]);
float
const
*
idata
=
reinterpret_cast
<
float
const
*>
(
inputs
[
0
]);
float
**
odatas
=
reinterpret_cast
<
float
**>
(
outputs
);
...
...
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
浏览文件 @
0b962680
...
...
@@ -14,7 +14,6 @@
#pragma once
#include <thrust/device_vector.h>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
...
...
@@ -25,19 +24,21 @@ class SplitPlugin : public PluginTensorRT {
int
axis_
;
std
::
vector
<
int
>
output_length_
;
int
nx_
,
ny_
,
nz_
;
thrust
::
device_vector
<
int
>
d_segment_offsets_
;
std
::
vector
<
int
>
segment_offsets_
;
protected:
virtual
size_t
getSerializationSize
()
override
{
return
serialized_size
(
axis_
)
+
serialized_s
ize
(
output_length_
)
+
return
SerializedSize
(
axis_
)
+
SerializedS
ize
(
output_length_
)
+
getBaseSerializationSize
();
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
virtual
void
serialize
(
void
*
buffer
)
override
{
serializeBase
(
buffer
);
serialize_v
alue
(
&
buffer
,
axis_
);
serialize_v
alue
(
&
buffer
,
output_length_
);
SerializeV
alue
(
&
buffer
,
axis_
);
SerializeV
alue
(
&
buffer
,
output_length_
);
}
public:
...
...
@@ -46,10 +47,12 @@ class SplitPlugin : public PluginTensorRT {
assert
(
axis
<=
nvinfer1
::
Dims
::
MAX_DIMS
);
}
// It was used for tensorrt deserialization.
// It should not be called by users.
SplitPlugin
(
void
const
*
serialData
,
size_t
serialLength
)
{
deserializeBase
(
serialData
,
serialLength
);
deserialize_v
alue
(
&
serialData
,
&
serialLength
,
&
axis_
);
deserialize_v
alue
(
&
serialData
,
&
serialLength
,
&
output_length_
);
DeserializeV
alue
(
&
serialData
,
&
serialLength
,
&
axis_
);
DeserializeV
alue
(
&
serialData
,
&
serialLength
,
&
output_length_
);
}
SplitPlugin
*
clone
()
const
override
{
...
...
@@ -64,12 +67,6 @@ class SplitPlugin : public PluginTensorRT {
virtual
int
initialize
()
override
;
virtual
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
void
setAxis
(
int
axis
)
{
axis_
=
axis
;
}
void
setOutputLengths
(
const
std
::
vector
<
int
>
&
output_lengths
)
{
output_length_
=
output_lengths
;
}
};
}
// tensorrt
...
...
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
浏览文件 @
0b962680
...
...
@@ -19,23 +19,23 @@ namespace inference {
namespace
tensorrt
{
void
PluginTensorRT
::
serializeBase
(
void
*&
buffer
)
{
serialize_v
alue
(
&
buffer
,
input_dims_
);
serialize_v
alue
(
&
buffer
,
max_batch_size_
);
serialize_v
alue
(
&
buffer
,
data_type_
);
serialize_v
alue
(
&
buffer
,
data_format_
);
SerializeV
alue
(
&
buffer
,
input_dims_
);
SerializeV
alue
(
&
buffer
,
max_batch_size_
);
SerializeV
alue
(
&
buffer
,
data_type_
);
SerializeV
alue
(
&
buffer
,
data_format_
);
}
void
PluginTensorRT
::
deserializeBase
(
void
const
*&
serialData
,
size_t
&
serialLength
)
{
deserialize_v
alue
(
&
serialData
,
&
serialLength
,
&
input_dims_
);
deserialize_v
alue
(
&
serialData
,
&
serialLength
,
&
max_batch_size_
);
deserialize_v
alue
(
&
serialData
,
&
serialLength
,
&
data_type_
);
deserialize_v
alue
(
&
serialData
,
&
serialLength
,
&
data_format_
);
DeserializeV
alue
(
&
serialData
,
&
serialLength
,
&
input_dims_
);
DeserializeV
alue
(
&
serialData
,
&
serialLength
,
&
max_batch_size_
);
DeserializeV
alue
(
&
serialData
,
&
serialLength
,
&
data_type_
);
DeserializeV
alue
(
&
serialData
,
&
serialLength
,
&
data_format_
);
}
size_t
PluginTensorRT
::
getBaseSerializationSize
()
{
return
(
serialized_size
(
input_dims_
)
+
serialized_s
ize
(
max_batch_size_
)
+
serialized_size
(
data_type_
)
+
serialized_s
ize
(
data_format_
));
return
(
SerializedSize
(
input_dims_
)
+
SerializedS
ize
(
max_batch_size_
)
+
SerializedSize
(
data_type_
)
+
SerializedS
ize
(
data_format_
));
}
bool
PluginTensorRT
::
supportsFormat
(
nvinfer1
::
DataType
type
,
...
...
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
浏览文件 @
0b962680
...
...
@@ -41,11 +41,7 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
size_t
getWorkspaceSize
(
int
)
const
override
{
return
0
;
}
void
terminate
()
override
{}
virtual
~
PluginTensorRT
()
{}
// The following functions need to be overrided in the subclass.
virtual
nvinfer1
::
IPluginExt
*
clone
()
const
=
0
;
virtual
const
char
*
getPluginType
()
const
=
0
;
int
initialize
()
override
{
return
0
;
}
// Check format support. The default is FLOAT32 and NCHW.
bool
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
)
const
override
;
void
configureWithFormat
(
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
,
...
...
@@ -53,12 +49,24 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
,
int
maxBatchSize
)
override
;
// *NOTE* The following functions need to be overrided in the subclass.
virtual
nvinfer1
::
IPluginExt
*
clone
()
const
=
0
;
virtual
const
char
*
getPluginType
()
const
=
0
;
// Initialize the layer for execution. This is called when the engine is
// created.
int
initialize
()
override
{
return
0
;
}
// Serialize the layer config to buffer.
virtual
void
serialize
(
void
*
buffer
)
=
0
;
virtual
size_t
getSerializationSize
()
=
0
;
virtual
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
=
0
;
protected:
// Deserialize input_dims, max_batch_size, data_type, data_format
void
deserializeBase
(
void
const
*&
serialData
,
size_t
&
serialLength
);
size_t
getBaseSerializationSize
();
// Serialize input_dims, max_batch_size, data_type, data_format
void
serializeBase
(
void
*&
buffer
);
std
::
vector
<
nvinfer1
::
Dims
>
input_dims_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录