Unverified commit 77ac30e5
Authored Nov 14, 2018 by Zhaolong Xing, committed by GitHub on Nov 14, 2018
Merge pull request #14386 from NHZlX/add_trt_plugin
add plugin support for paddle-trt
Parents: 8cfda7ee, 15bdb7ef
Showing 15 changed files, with 557 additions and 5 deletions.
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc  +1  -1
paddle/fluid/inference/api/analysis_predictor.cc  +1  -0
paddle/fluid/inference/tensorrt/CMakeLists.txt  +1  -0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt  +6  -3
paddle/fluid/inference/tensorrt/convert/concat_op.cc  +1  -1
paddle/fluid/inference/tensorrt/convert/split_op.cc  +75  -0
paddle/fluid/inference/tensorrt/convert/test_split_op.cc  +53  -0
paddle/fluid/inference/tensorrt/engine.cc  +6  -0
paddle/fluid/inference/tensorrt/engine.h  +5  -0
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt  +1  -0
paddle/fluid/inference/tensorrt/plugin/serialize.h  +111  -0
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu  +81  -0
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h  +74  -0
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc  +61  -0
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h  +80  -0
paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
       std::unordered_set<std::string> teller_set(
           {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
            "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
-           "elementwise_add", "dropout"});
+           "elementwise_add", "dropout", "split"});
       if (!node->IsOp()) return false;
 
       if (teller_set.count(node->Op()->Type())) {
paddle/fluid/inference/api/analysis_predictor.cc
@@ -548,4 +548,5 @@ USE_TRT_CONVERTER(batch_norm);
 USE_TRT_CONVERTER(concat);
 USE_TRT_CONVERTER(dropout);
 USE_TRT_CONVERTER(pad);
+USE_TRT_CONVERTER(split);
 #endif
paddle/fluid/inference/tensorrt/CMakeLists.txt
 nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
 nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
+add_subdirectory(plugin)
 add_subdirectory(convert)
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
 # Add TRT tests
 nv_library(tensorrt_converter
   SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
-  DEPS tensorrt_engine operator scope framework_proto op_registry)
+  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
+  pad_op.cc split_op.cc
+  DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
 
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
         ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
          DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
 nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
          DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
 nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
          DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
+nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
+        split_op concat_op SERIAL)
paddle/fluid/inference/tensorrt/convert/concat_op.cc
@@ -19,7 +19,7 @@ namespace inference {
 namespace tensorrt {
 
 /*
- * MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
+ * ConcatOp
  */
 class ConcatOpConverter : public OpConverter {
  public:
paddle/fluid/inference/tensorrt/convert/split_op.cc
new file mode 100644
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * SplitOp.
 */
class SplitOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(40) << "convert a fluid split op to tensorrt split layer";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    auto input_dims = input->getDimensions();
    int input_num = op_desc.Input("X").size();
    size_t output_num = op_desc.Output("Out").size();

    // Get Attrs
    PADDLE_ENFORCE(input_num == 1);
    int axis = boost::get<int>(op_desc.GetAttr("axis"));
    std::vector<int> output_lengths =
        boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
    PADDLE_ENFORCE(axis != 0);
    if (axis < 0) {
      axis += input_dims.nbDims;
    } else {
      axis -= 1;
    }

    PADDLE_ENFORCE(output_lengths.size() == output_num);

    SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
    nvinfer1::IPluginLayer* layer =
        engine_->AddPlugin(&input, input_num, plugin);

    std::string layer_name = "split (Output: ";
    for (size_t i = 0; i < output_num; i++) {
      auto output_name = op_desc.Output("Out")[i];
      layer->getOutput(i)->setName(output_name.c_str());
      engine_->SetITensor(output_name, layer->getOutput(i));
      layer_name += output_name;
      if (test_mode) {
        engine_->DeclareOutput(output_name);
      }
    }
    layer->setName((layer_name + ")").c_str());
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(split, SplitOpConverter);
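One detail worth spelling out in the converter above: Paddle's split op numbers its axis over the full tensor including the batch dimension, while the TensorRT network definition works on per-sample dimensions only, so a non-negative axis is shifted down by one (and a negative axis is wrapped against the TRT dims) before being handed to SplitPlugin. A minimal sketch of that mapping; MapSplitAxisToTrt is a hypothetical helper name used for illustration only, not part of this commit:

// Hedged sketch: map a Paddle `split` axis onto a batch-less TensorRT axis.
inline int MapSplitAxisToTrt(int paddle_axis, int trt_nb_dims) {
  if (paddle_axis < 0) return paddle_axis + trt_nb_dims;  // wrap negative axes
  return paddle_axis - 1;  // drop the batch dimension that Paddle counts
}

// Example: splitting an [N, C, H, W] Paddle tensor on axis 1 (channels)
// becomes a split on axis 0 of the engine's per-sample CHW dims.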
paddle/fluid/inference/tensorrt/convert/test_split_op.cc
new file mode 100644
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {

TEST(split_op, test) {
  std::unordered_set<std::string> parameters({""});
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
  validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("split");
  desc.SetInput("X", {"split_input"});
  desc.SetOutput("Out", {"split_out1", "split_out2"});

  int num = 0;
  int axis = 1;
  std::vector<int> output_lengths = {2, 1};
  desc.SetAttr("axis", axis);
  desc.SetAttr("num", num);
  desc.SetAttr("sections", output_lengths);

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
USE_OP(split);
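The shapes in this test follow from the converter's axis handling: the input is declared as CHW (3, 2, 2), Paddle axis 1 lands on the channel dimension of the batch-less TensorRT dims, and sections {2, 1} therefore produce outputs of CHW (2, 2, 2) and CHW (1, 2, 2). TRTConvertValidation runs the op through both the Fluid reference implementation and the generated TensorRT engine and checks that the two results agree.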
paddle/fluid/inference/tensorrt/engine.cc
@@ -255,6 +255,12 @@ void TensorRTEngine::freshDeviceId() {
   cudaSetDevice(device_);
 }
 
+nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
+    nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
+  owned_plugin_.emplace_back(plugin);
+  return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
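AddPlugin is the one entry point op converters use to splice an IPluginExt-based layer into the network under construction, and it also hands ownership of the plugin object to the engine (via owned_plugin_), since TensorRT's addPluginExt does not take ownership itself. A hedged sketch of the calling convention from inside a converter, mirroring what split_op.cc does above; the tensor name is illustrative only:

// engine_ is the TensorRTEngine* available inside an OpConverter.
nvinfer1::ITensor* input = engine_->GetITensor("some_input");  // hypothetical name
SplitPlugin* plugin = new SplitPlugin(/*axis=*/0, /*output_lengths=*/{2, 1});
nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, /*nbInputs=*/1, plugin);
// The engine now owns `plugin`; the converter must not delete it.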
paddle/fluid/inference/tensorrt/engine.h
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
@@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase {
   void SetRuntimeBatch(size_t batch_size);
   int GetRuntimeBatch();
   int GetDevice() { return device_; }
+  nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
+                                    int nbInputs, PluginTensorRT*);
 
   // A pointer to CPU memory is needed of the TRT weight.
   // Before TRT runs, fluid loads weight into GPU storage.
@@ -164,8 +167,10 @@ class TensorRTEngine : public EngineBase {
   std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
   std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
       itensor_map_;
 
   // The specific GPU id that the TensorRTEngine bounded to.
   int device_;
+  std::vector<std::unique_ptr<PluginTensorRT>> owned_plugin_;
 
   // TensorRT related internal members
   template <typename T>
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
new file mode 100644
nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)
paddle/fluid/inference/tensorrt/plugin/serialize.h
new file mode 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstring>
#include <type_traits>
#include <vector>
template <typename T>
inline void SerializeValue(void** buffer, T const& value);

template <typename T>
inline void DeserializeValue(void const** buffer, size_t* buffer_size,
                             T* value);

namespace {

template <typename T, class Enable = void>
struct Serializer {};

template <typename T>
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
                                             std::is_enum<T>::value ||
                                             std::is_pod<T>::value>::type> {
  static size_t SerializedSize(T const& value) { return sizeof(T); }
  static void Serialize(void** buffer, T const& value) {
    std::memcpy(*buffer, &value, sizeof(T));
    reinterpret_cast<char*&>(*buffer) += sizeof(T);
  }
  static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
    assert(*buffer_size >= sizeof(T));
    std::memcpy(value, *buffer, sizeof(T));
    reinterpret_cast<char const*&>(*buffer) += sizeof(T);
    *buffer_size -= sizeof(T);
  }
};

template <>
struct Serializer<const char*> {
  static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
  static void Serialize(void** buffer, const char* value) {
    std::strcpy(static_cast<char*>(*buffer), value);
    reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
  }
  static void Deserialize(void const** buffer, size_t* buffer_size,
                          const char** value) {
    *value = static_cast<char const*>(*buffer);
    size_t data_size = strnlen(*value, *buffer_size) + 1;
    assert(*buffer_size >= data_size);
    reinterpret_cast<char const*&>(*buffer) += data_size;
    *buffer_size -= data_size;
  }
};

template <typename T>
struct Serializer<std::vector<T>,
                  typename std::enable_if<std::is_arithmetic<T>::value ||
                                          std::is_enum<T>::value ||
                                          std::is_pod<T>::value>::type> {
  static size_t SerializedSize(std::vector<T> const& value) {
    return sizeof(value.size()) + value.size() * sizeof(T);
  }
  static void Serialize(void** buffer, std::vector<T> const& value) {
    SerializeValue(buffer, value.size());
    size_t nbyte = value.size() * sizeof(T);
    std::memcpy(*buffer, value.data(), nbyte);
    reinterpret_cast<char*&>(*buffer) += nbyte;
  }
  static void Deserialize(void const** buffer, size_t* buffer_size,
                          std::vector<T>* value) {
    size_t size;
    DeserializeValue(buffer, buffer_size, &size);
    value->resize(size);
    size_t nbyte = value->size() * sizeof(T);
    assert(*buffer_size >= nbyte);
    std::memcpy(value->data(), *buffer, nbyte);
    reinterpret_cast<char const*&>(*buffer) += nbyte;
    *buffer_size -= nbyte;
  }
};

}  // namespace

template <typename T>
inline size_t SerializedSize(T const& value) {
  return Serializer<T>::SerializedSize(value);
}

template <typename T>
inline void SerializeValue(void** buffer, T const& value) {
  return Serializer<T>::Serialize(buffer, value);
}

template <typename T>
inline void DeserializeValue(void const** buffer, size_t* buffer_size,
                             T* value) {
  return Serializer<T>::Deserialize(buffer, buffer_size, value);
}
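These helpers are free functions over a raw, caller-managed buffer: SerializedSize reports how many bytes a value needs, SerializeValue writes the value and advances the write cursor, and DeserializeValue reads it back while shrinking the remaining-size counter. A hedged round-trip sketch (the wrapping function and the values are illustrative only; it assumes this header and <vector> are included):

// Round-trip sketch using the helpers above.
void SerializeRoundTripExample() {
  int axis = 1;
  std::vector<int> sections = {2, 1};

  // 1) Size and allocate the buffer.
  std::vector<char> storage(SerializedSize(axis) + SerializedSize(sections));

  // 2) Serialize; `cursor` is advanced past each value that is written.
  void* cursor = storage.data();
  SerializeValue(&cursor, axis);
  SerializeValue(&cursor, sections);

  // 3) Deserialize; `remaining` shrinks as values are consumed.
  void const* read_cursor = storage.data();
  size_t remaining = storage.size();
  int axis_out = 0;
  std::vector<int> sections_out;
  DeserializeValue(&read_cursor, &remaining, &axis_out);
  DeserializeValue(&read_cursor, &remaining, &sections_out);
}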
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
new file mode 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {

nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims* inputDims,
                                                int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  nvinfer1::Dims const& input_dims = inputDims[0];
  nvinfer1::Dims output_dims = input_dims;
  output_dims.d[axis_] = output_length_.at(index);
  return output_dims;
}

int SplitPlugin::initialize() {
  std::vector<int> segment_offsets(1, 0);
  for (int i = 0; i < this->getNbOutputs(); ++i) {
    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
  }
  segment_offsets_ = segment_offsets;
  nvinfer1::Dims dims = this->getInputDims(0);
  nx_ = 1;
  for (int i = dims.nbDims - 1; i > axis_; --i) {
    nx_ *= dims.d[i];
  }
  ny_ = dims.d[axis_];
  nz_ = 1;
  for (int i = axis_ - 1; i >= 0; --i) {
    nz_ *= dims.d[i];
  }
  return 0;
}

int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                         void** outputs, void* workspace, cudaStream_t stream) {
  auto const& input_dims = this->getInputDims(0);
  int input_size = 0;
  float const* idata = reinterpret_cast<float const*>(inputs[0]);
  float** odatas = reinterpret_cast<float**>(outputs);

  // kernel impl here.
  int inputBatchOffset = nx_ * ny_ * nz_;
  for (size_t i = 0; i < this->getNbOutputs(); i++) {
    for (size_t j = 0; j < batchSize; j++) {
      cudaMemcpyAsync(
          odatas[i] +
              j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ *
                  sizeof(float),
          inputs[0] +
              (inputBatchOffset * j + segment_offsets_[i] * nx_) *
                  sizeof(float),
          (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float),
          cudaMemcpyDeviceToDevice, stream);
    }
  }

  return cudaGetLastError() != cudaSuccess;
}

}  // tensorrt
}  // inference
}  // paddle
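In initialize(), the per-sample input volume is factored around the split axis: nx_ is the product of the dimensions after axis_ (the contiguous inner extent), ny_ is the size of the split axis itself, nz_ the product of the dimensions before it, and segment_offsets_ holds the running sum of the section lengths. enqueue() then issues plain device-to-device cudaMemcpyAsync calls on the provided stream, per output and per sample, rather than launching a dedicated kernel (the "kernel impl here" comment marks where one could go). A hedged worked example of the bookkeeping, using the shapes from the unit test:

// Per-sample input CHW = (3, 2, 2), split on the C axis with sections {2, 1}:
//   nx_ = H * W = 4, ny_ = C = 3, nz_ = 1, segment_offsets_ = {0, 2, 3};
//   inputBatchOffset = nx_ * ny_ * nz_ = 12 floats per sample.
// For sample j, output 0 receives the 2 * 4 = 8 floats starting at input
// element 12 * j + 0 * 4, and output 1 the 1 * 4 = 4 floats starting at
// input element 12 * j + 2 * 4.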
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
new file mode 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {

class SplitPlugin : public PluginTensorRT {
  int axis_;
  std::vector<int> output_length_;
  int nx_, ny_, nz_;
  std::vector<int> segment_offsets_;

 protected:
  virtual size_t getSerializationSize() override {
    return SerializedSize(axis_) + SerializedSize(output_length_) +
           getBaseSerializationSize();
  }

  // TRT will call this func when we need to serialize the configuration of
  // tensorrt.
  // It should not be called by users.
  virtual void serialize(void* buffer) override {
    serializeBase(buffer);
    SerializeValue(&buffer, axis_);
    SerializeValue(&buffer, output_length_);
  }

 public:
  SplitPlugin(int axis, std::vector<int> const& output_lengths)
      : axis_(axis), output_length_(output_lengths) {
    assert(axis <= nvinfer1::Dims::MAX_DIMS);
  }

  // It was used for tensorrt deserialization.
  // It should not be called by users.
  SplitPlugin(void const* serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &axis_);
    DeserializeValue(&serialData, &serialLength, &output_length_);
  }

  SplitPlugin* clone() const override {
    return new SplitPlugin(axis_, output_length_);
  }

  virtual const char* getPluginType() const override { return "split"; }
  virtual int getNbOutputs() const override { return output_length_.size(); }
  virtual nvinfer1::Dims getOutputDimensions(int index,
                                             const nvinfer1::Dims* inputs,
                                             int nbInputDims) override;
  virtual int initialize() override;
  virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
                      void* workspace, cudaStream_t stream) override;
};

}  // tensorrt
}  // inference
}  // paddle
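The two constructors mirror each other: the (axis, output_lengths) form is used when the network is first built, while the (serialData, serialLength) form is what a deserialization path would call, undoing exactly what getSerializationSize()/serialize() produced (the base fields first, then axis_ and output_length_). A hedged sketch of that round trip; in practice TensorRT drives these calls itself, and the buffer handling here is illustrative only:

SplitPlugin original(/*axis=*/0, /*output_lengths=*/{2, 1});
nvinfer1::IPluginExt* as_plugin = &original;  // TRT talks to the plugin through this interface
// (TRT would have called configureWithFormat/initialize on `original` before this point.)

// Engine-serialization time: TRT asks for the size, then for the bytes.
std::vector<char> blob(as_plugin->getSerializationSize());
as_plugin->serialize(blob.data());

// Engine-deserialization time: a plugin factory rebuilds the plugin from the bytes.
SplitPlugin restored(blob.data(), blob.size());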
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
new file mode 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {

void PluginTensorRT::serializeBase(void*& buffer) {
  SerializeValue(&buffer, input_dims_);
  SerializeValue(&buffer, max_batch_size_);
  SerializeValue(&buffer, data_type_);
  SerializeValue(&buffer, data_format_);
}

void PluginTensorRT::deserializeBase(void const*& serialData,
                                     size_t& serialLength) {
  DeserializeValue(&serialData, &serialLength, &input_dims_);
  DeserializeValue(&serialData, &serialLength, &max_batch_size_);
  DeserializeValue(&serialData, &serialLength, &data_type_);
  DeserializeValue(&serialData, &serialLength, &data_format_);
}

size_t PluginTensorRT::getBaseSerializationSize() {
  return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
          SerializedSize(data_type_) + SerializedSize(data_format_));
}

bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
                                    nvinfer1::PluginFormat format) const {
  return ((type == nvinfer1::DataType::kFLOAT) &&
          (format == nvinfer1::PluginFormat::kNCHW));
}

void PluginTensorRT::configureWithFormat(const nvinfer1::Dims* inputDims,
                                         int nbInputs,
                                         const nvinfer1::Dims* outputDims,
                                         int nbOutputs, nvinfer1::DataType type,
                                         nvinfer1::PluginFormat format,
                                         int maxBatchSize) {
  data_type_ = type;
  data_format_ = format;
  input_dims_.assign(inputDims, inputDims + nbInputs);
  max_batch_size_ = maxBatchSize;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
new file mode 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstring>
#include <iostream>
#include <unordered_map>
#include <vector>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"
namespace paddle {
namespace inference {
namespace tensorrt {

class PluginTensorRT : public nvinfer1::IPluginExt {
 public:
  PluginTensorRT() {}
  PluginTensorRT(const void* serialized_data, size_t length) {}
  nvinfer1::Dims const& getInputDims(int index) const {
    return input_dims_.at(index);
  }
  size_t getMaxBatchSize() const { return max_batch_size_; }
  nvinfer1::DataType getDataType() const { return data_type_; }
  nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
  virtual const char* getPluginVersion() const { return "1"; }
  size_t getWorkspaceSize(int) const override { return 0; }
  void terminate() override {}
  virtual ~PluginTensorRT() {}
  // Check format support. The default is FLOAT32 and NCHW.
  bool supportsFormat(nvinfer1::DataType type,
                      nvinfer1::PluginFormat format) const override;
  void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
                           const nvinfer1::Dims* outputDims, int nbOutputs,
                           nvinfer1::DataType type,
                           nvinfer1::PluginFormat format,
                           int maxBatchSize) override;

  // *NOTE* The following functions need to be overrided in the subclass.
  virtual nvinfer1::IPluginExt* clone() const = 0;
  virtual const char* getPluginType() const = 0;
  // Initialize the layer for execution. This is called when the engine is
  // created.
  int initialize() override { return 0; }
  // Serialize the layer config to buffer.
  virtual void serialize(void* buffer) = 0;
  virtual size_t getSerializationSize() = 0;
  virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
                      void* workspace, cudaStream_t stream) = 0;

 protected:
  // Deserialize input_dims, max_batch_size, data_type, data_format
  void deserializeBase(void const*& serialData, size_t& serialLength);
  size_t getBaseSerializationSize();
  // Serialize input_dims, max_batch_size, data_type, data_format
  void serializeBase(void*& buffer);

  std::vector<nvinfer1::Dims> input_dims_;
  size_t max_batch_size_;
  nvinfer1::DataType data_type_;
  nvinfer1::PluginFormat data_format_;
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
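PluginTensorRT fixes the shared boilerplate (format negotiation plus serialization of input_dims_, max_batch_size_, data_type_ and data_format_), so adding another plugin op mainly means filling in the pure-virtual surface above. A hedged skeleton of what such a subclass could look like; IdentityPlugin is a made-up name used for illustration and is not part of this commit:

namespace paddle {
namespace inference {
namespace tensorrt {

class IdentityPlugin : public PluginTensorRT {  // hypothetical example
 public:
  IdentityPlugin() {}
  IdentityPlugin(void const* data, size_t length) { deserializeBase(data, length); }

  IdentityPlugin* clone() const override { return new IdentityPlugin(); }
  const char* getPluginType() const override { return "identity"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                     int nbInputDims) override {
    return inputs[0];  // same shape as the single input
  }
  int enqueue(int batchSize, const void* const* inputs, void** outputs,
              void* workspace, cudaStream_t stream) override {
    // A real plugin would launch a kernel or an async copy on `stream` here.
    return 0;
  }

 protected:
  size_t getSerializationSize() override { return getBaseSerializationSize(); }
  void serialize(void* buffer) override { serializeBase(buffer); }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle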