Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d38fd6a0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d38fd6a0
编写于
11月 13, 2018
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add plugin support and offer an simple split sample
上级
2d7134bc
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
208 addition
and
210 deletion
+208
-210
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+1
-1
paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
+1
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+5
-2
paddle/fluid/inference/tensorrt/convert/split_op.cc
paddle/fluid/inference/tensorrt/convert/split_op.cc
+73
-0
paddle/fluid/inference/tensorrt/convert/test_split_op.cc
paddle/fluid/inference/tensorrt/convert/test_split_op.cc
+53
-0
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+6
-0
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+5
-0
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-2
paddle/fluid/inference/tensorrt/plugin/plugin_factory.h
paddle/fluid/inference/tensorrt/plugin/plugin_factory.h
+0
-91
paddle/fluid/inference/tensorrt/plugin/plugin_utils.h
paddle/fluid/inference/tensorrt/plugin/plugin_utils.h
+0
-34
paddle/fluid/inference/tensorrt/plugin/serialize.h
paddle/fluid/inference/tensorrt/plugin/serialize.h
+0
-0
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
+20
-50
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
+38
-23
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
+1
-3
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
+4
-4
未找到文件。
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
d38fd6a0
...
@@ -71,7 +71,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
...
@@ -71,7 +71,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
std
::
unordered_set
<
std
::
string
>
teller_set
(
std
::
unordered_set
<
std
::
string
>
teller_set
(
{
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"sigmoid"
,
{
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"sigmoid"
,
"depthwise_conv2d"
,
"batch_norm"
,
"concat"
,
"tanh"
,
"pad"
,
"depthwise_conv2d"
,
"batch_norm"
,
"concat"
,
"tanh"
,
"pad"
,
"elementwise_add"
,
"dropout"
});
"elementwise_add"
,
"dropout"
,
"split"
});
if
(
!
node
->
IsFunction
())
return
false
;
if
(
!
node
->
IsFunction
())
return
false
;
const
auto
*
func
=
static_cast
<
const
Function
*>
(
node
);
const
auto
*
func
=
static_cast
<
const
Function
*>
(
node
);
...
...
paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
浏览文件 @
d38fd6a0
...
@@ -186,3 +186,4 @@ USE_TRT_CONVERTER(batch_norm);
...
@@ -186,3 +186,4 @@ USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER
(
concat
);
USE_TRT_CONVERTER
(
concat
);
USE_TRT_CONVERTER
(
dropout
);
USE_TRT_CONVERTER
(
dropout
);
USE_TRT_CONVERTER
(
pad
);
USE_TRT_CONVERTER
(
pad
);
USE_TRT_CONVERTER
(
split
);
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
d38fd6a0
# Add TRT tests
# Add TRT tests
nv_library
(
tensorrt_converter
nv_library
(
tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
...
@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
...
@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine concat_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine concat_op SERIAL
)
nv_test
(
test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
nv_test
(
test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine dropout_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine dropout_op SERIAL
)
nv_test
(
test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
nv_test
(
test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine pad_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine pad_op SERIAL
)
nv_test
(
test_trt_split_op SRCS test_split_op.cc split_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine tensorrt_plugin
split_op concat_op SERIAL
)
paddle/fluid/inference/tensorrt/
plugin/plugin_factory
.cc
→
paddle/fluid/inference/tensorrt/
convert/split_op
.cc
浏览文件 @
d38fd6a0
...
@@ -12,53 +12,62 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,53 +12,62 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
tensorrt
{
namespace
tensorrt
{
PluginTensorRT
*
PluginFactoryTensorRT
::
createPlugin
(
const
char
*
layer_name
,
/*
const
void
*
serial_data
,
* SplitOp.
size_t
serial_length
)
{
*/
size_t
parsed_byte
=
0
;
class
SplitOpConverter
:
public
OpConverter
{
std
::
string
encoded_op_name
=
public:
ExtractOpName
(
serial_data
,
serial_length
,
&
parsed_byte
);
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
if
(
!
IsPlugin
(
encoded_op_name
))
{
VLOG
(
40
)
<<
"convert a fluid split op to tensorrt split layer"
;
return
nullptr
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
input_dims
=
input
->
getDimensions
();
int
input_num
=
op_desc
.
Input
(
"X"
).
size
();
size_t
output_num
=
op_desc
.
Output
(
"Out"
).
size
();
PADDLE_ENFORCE
(
input_num
==
1
);
int
axis
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"axis"
));
std
::
vector
<
int
>
output_lengths
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"sections"
));
PADDLE_ENFORCE
(
axis
!=
0
);
if
(
axis
<
0
)
{
axis
+=
input_dims
.
nbDims
;
}
else
{
axis
-=
1
;
}
PADDLE_ENFORCE
(
output_lengths
.
size
()
==
output_num
);
SplitPlugin
*
plugin
=
new
SplitPlugin
(
axis
,
output_lengths
);
nvinfer1
::
IPluginLayer
*
layer
=
engine_
->
addPlugin
(
&
input
,
input_num
,
plugin
);
std
::
string
layer_name
=
"split (Output: "
;
for
(
size_t
i
=
0
;
i
<
output_num
;
i
++
)
{
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
i
];
layer
->
getOutput
(
i
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
i
));
layer_name
+=
output_name
;
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
}
}
layer
->
setName
((
layer_name
+
")"
).
c_str
());
}
}
};
auto
plugin_ptr
=
plugin_registry_
[
encoded_op_name
].
first
(
serial_data
,
serial_length
);
owned_plugins_
.
emplace_back
(
plugin_ptr
);
return
plugin_ptr
;
}
PluginTensorRT
*
PluginFactoryTensorRT
::
CreatePlugin
(
const
std
::
string
&
op_name
)
{
if
(
!
IsPlugin
(
op_name
))
return
nullptr
;
auto
plugin_ptr
=
plugin_registry_
[
op_name
].
second
();
owned_plugins_
.
emplace_back
(
plugin_ptr
);
return
plugin_ptr
;
}
bool
PluginFactoryTensorRT
::
RegisterPlugin
(
const
std
::
string
&
op_name
,
PluginDeserializeFunc
deserialize_func
,
PluginConstructFunc
construct_func
)
{
if
(
IsPlugin
(
op_name
))
return
false
;
auto
ret
=
plugin_registry_
.
emplace
(
op_name
,
std
::
make_pair
(
deserialize_func
,
construct_func
));
return
ret
.
second
;
}
void
PluginFactoryTensorRT
::
DestroyPlugins
()
{
owned_plugins_
.
clear
();
}
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
split
,
SplitOpConverter
);
paddle/fluid/inference/tensorrt/
plugin/plugin_utils
.cc
→
paddle/fluid/inference/tensorrt/
convert/test_split_op
.cc
浏览文件 @
d38fd6a0
...
@@ -12,26 +12,42 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,26 +12,42 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
#include <gtest/gtest.h>
#include <cassert>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
tensorrt
{
namespace
tensorrt
{
std
::
string
ExtractOpName
(
const
void
*
serial_data
,
size_t
serial_length
,
TEST
(
split_op
,
test
)
{
size_t
*
incremental
)
{
std
::
unordered_set
<
std
::
string
>
parameters
({
""
});
size_t
op_name_char_count
=
*
static_cast
<
const
size_t
*>
(
serial_data
);
framework
::
Scope
scope
;
*
incremental
=
sizeof
(
size_t
)
+
op_name_char_count
;
TRTConvertValidation
validator
(
10
,
parameters
,
scope
,
1000
);
validator
.
DeclInputVar
(
"split_input"
,
nvinfer1
::
DimsCHW
(
3
,
2
,
2
));
assert
(
serial_length
>=
*
incremental
);
validator
.
DeclOutputVar
(
"split_out1"
,
nvinfer1
::
DimsCHW
(
2
,
2
,
2
));
validator
.
DeclOutputVar
(
"split_out2"
,
nvinfer1
::
DimsCHW
(
1
,
2
,
2
));
const
char
*
buffer
=
static_cast
<
const
char
*>
(
serial_data
)
+
sizeof
(
size_t
);
std
::
string
op_name
(
buffer
,
op_name_char_count
);
// Prepare Op description
framework
::
OpDesc
desc
;
return
op_name
;
desc
.
SetType
(
"split"
);
desc
.
SetInput
(
"X"
,
{
"split_input"
});
desc
.
SetOutput
(
"Out"
,
{
"split_out1"
,
"split_out2"
});
int
num
=
0
;
int
axis
=
1
;
std
::
vector
<
int
>
output_lengths
=
{
2
,
1
};
desc
.
SetAttr
(
"axis"
,
axis
);
desc
.
SetAttr
(
"num"
,
num
);
desc
.
SetAttr
(
"sections"
,
output_lengths
);
validator
.
SetOp
(
*
desc
.
Proto
());
validator
.
Execute
(
1
);
}
}
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
USE_OP
(
split
);
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
d38fd6a0
...
@@ -254,6 +254,12 @@ void TensorRTEngine::freshDeviceId() {
...
@@ -254,6 +254,12 @@ void TensorRTEngine::freshDeviceId() {
cudaSetDevice
(
device_
);
cudaSetDevice
(
device_
);
}
}
nvinfer1
::
IPluginLayer
*
TensorRTEngine
::
addPlugin
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
nbInputs
,
PluginTensorRT
*
plugin
)
{
owned_plugin_
.
emplace_back
(
plugin
);
return
infer_network_
.
get
()
->
addPluginExt
(
inputs
,
nbInputs
,
*
plugin
);
}
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
d38fd6a0
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase {
...
@@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase {
void
SetRuntimeBatch
(
size_t
batch_size
);
void
SetRuntimeBatch
(
size_t
batch_size
);
int
GetRuntimeBatch
();
int
GetRuntimeBatch
();
int
GetDevice
()
{
return
device_
;
}
int
GetDevice
()
{
return
device_
;
}
nvinfer1
::
IPluginLayer
*
addPlugin
(
nvinfer1
::
ITensor
*
const
*
inputs
,
int
nbInputs
,
PluginTensorRT
*
);
// A pointer to CPU memory is needed of the TRT weight.
// A pointer to CPU memory is needed of the TRT weight.
// Before TRT runs, fluid loads weight into GPU storage.
// Before TRT runs, fluid loads weight into GPU storage.
...
@@ -164,8 +167,10 @@ class TensorRTEngine : public EngineBase {
...
@@ -164,8 +167,10 @@ class TensorRTEngine : public EngineBase {
std
::
unordered_map
<
std
::
string
/*name*/
,
size_t
/*max size*/
>
buffer_sizes_
;
std
::
unordered_map
<
std
::
string
/*name*/
,
size_t
/*max size*/
>
buffer_sizes_
;
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
std
::
unordered_map
<
std
::
string
/*name*/
,
nvinfer1
::
ITensor
*
/*ITensor*/
>
itensor_map_
;
itensor_map_
;
// The specific GPU id that the TensorRTEngine bounded to.
// The specific GPU id that the TensorRTEngine bounded to.
int
device_
;
int
device_
;
std
::
vector
<
std
::
unique_ptr
<
PluginTensorRT
>>
owned_plugin_
;
// TensorRT related internal members
// TensorRT related internal members
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
d38fd6a0
nv_library
(
tensorrt_plugin SRCS plugin_factory.cc plugin_utils.cc
nv_library
(
tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce
)
trt_plugin.cc split_op_plugin.cu DEPS enforce
)
paddle/fluid/inference/tensorrt/plugin/plugin_factory.h
已删除
100644 → 0
浏览文件 @
2d7134bc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <unordered_map>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
PluginFactoryTensorRT
:
public
nvinfer1
::
IPluginFactory
{
public:
static
PluginFactoryTensorRT
*
GetInstance
()
{
static
PluginFactoryTensorRT
*
factory_instance
=
new
PluginFactoryTensorRT
();
return
factory_instance
;
}
// Deserialization method
PluginTensorRT
*
createPlugin
(
const
char
*
layer_name
,
const
void
*
serial_data
,
size_t
serial_length
)
override
;
// Plugin construction, PluginFactoryTensorRT owns the plugin.
PluginTensorRT
*
CreatePlugin
(
const
std
::
string
&
op_name
);
bool
RegisterPlugin
(
const
std
::
string
&
op_name
,
PluginDeserializeFunc
deserialize_func
,
PluginConstructFunc
construct_func
);
bool
IsPlugin
(
const
std
::
string
&
op_name
)
{
return
plugin_registry_
.
find
(
op_name
)
!=
plugin_registry_
.
end
();
}
size_t
CountOwnedPlugins
()
{
return
owned_plugins_
.
size
();
}
void
DestroyPlugins
();
protected:
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
PluginDeserializeFunc
,
PluginConstructFunc
>>
plugin_registry_
;
std
::
vector
<
std
::
unique_ptr
<
PluginTensorRT
>>
owned_plugins_
;
};
class
TrtPluginRegistrar
{
public:
TrtPluginRegistrar
(
const
std
::
string
&
name
,
PluginDeserializeFunc
deserialize_func
,
PluginConstructFunc
construct_func
)
{
auto
factory
=
PluginFactoryTensorRT
::
GetInstance
();
// platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func,
// construct_func), "Falied to register plugin [%s]", name);
// platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func,
// construct_func));
factory
->
RegisterPlugin
(
name
,
deserialize_func
,
construct_func
);
}
};
#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \
REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \
construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \
construct_func) \
REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \
static ::paddle::inference::tensorrt::TrtPluginRegistrar \
trt_plugin_registrar##ctr __attribute__((unused)) = \
::paddle::inference::tensorrt::TrtPluginRegistrar( \
name, deserialize_func, construct_func)
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/plugin_utils.h
已删除
100644 → 0
浏览文件 @
2d7134bc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
typedef
std
::
function
<
PluginTensorRT
*
(
const
void
*
,
size_t
)
>
PluginDeserializeFunc
;
typedef
std
::
function
<
PluginTensorRT
*
(
void
)
>
PluginConstructFunc
;
std
::
string
ExtractOpName
(
const
void
*
serial_data
,
size_t
serial_length
,
size_t
*
incremental
);
}
// namespace tensorrt
}
// namespace inference
}
// namespze paddle
paddle/fluid/inference/tensorrt/plugin/serialize.h
pp
→
paddle/fluid/inference/tensorrt/plugin/serialize.h
浏览文件 @
d38fd6a0
文件已移动
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
浏览文件 @
d38fd6a0
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <cassert>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
...
@@ -19,8 +20,6 @@ namespace paddle {
...
@@ -19,8 +20,6 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
tensorrt
{
namespace
tensorrt
{
SplitPlugin
*
CreateSplitPlugin
()
{
return
new
SplitPlugin
();
};
nvinfer1
::
Dims
SplitPlugin
::
getOutputDimensions
(
int
index
,
nvinfer1
::
Dims
SplitPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputDims
,
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
)
{
int
nbInputs
)
{
...
@@ -28,15 +27,16 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
...
@@ -28,15 +27,16 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
assert
(
index
<
this
->
getNbOutputs
());
assert
(
index
<
this
->
getNbOutputs
());
nvinfer1
::
Dims
const
&
input_dims
=
inputDims
[
0
];
nvinfer1
::
Dims
const
&
input_dims
=
inputDims
[
0
];
nvinfer1
::
Dims
output_dims
=
input_dims
;
nvinfer1
::
Dims
output_dims
=
input_dims
;
output_dims
.
d
[
axis_
]
=
output_leng
ht
_
.
at
(
index
);
output_dims
.
d
[
axis_
]
=
output_leng
th
_
.
at
(
index
);
return
output_dims
;
return
output_dims
;
}
}
int
SplitPlugin
::
initialize
()
{
int
SplitPlugin
::
initialize
()
{
std
::
vector
<
int
>
segment_offsets
(
1
,
0
);
std
::
vector
<
int
>
segment_offsets
(
1
,
0
);
for
(
int
i
=
0
;
i
<
this
->
getNbOutputs
();
++
i
)
{
for
(
int
i
=
0
;
i
<
this
->
getNbOutputs
();
++
i
)
{
segment_offsets
.
push_back
(
segment_offsets
.
back
()
+
output_leng
ht
_
[
i
]);
segment_offsets
.
push_back
(
segment_offsets
.
back
()
+
output_leng
th
_
[
i
]);
}
}
segment_offsets_
=
segment_offsets
;
d_segment_offsets_
=
segment_offsets
;
d_segment_offsets_
=
segment_offsets
;
nvinfer1
::
Dims
dims
=
this
->
getInputDims
(
0
);
nvinfer1
::
Dims
dims
=
this
->
getInputDims
(
0
);
nx_
=
1
;
nx_
=
1
;
...
@@ -51,60 +51,30 @@ int SplitPlugin::initialize() {
...
@@ -51,60 +51,30 @@ int SplitPlugin::initialize() {
return
0
;
return
0
;
}
}
template
<
typename
T
>
__device__
int
upper_bound
(
T
const
*
vals
,
int
n
,
T
const
&
key
)
{
int
i
=
0
;
while
(
n
>
0
)
{
int
m
=
n
/
2
;
int
j
=
i
+
m
;
if
(
!
(
key
<
vals
[
j
]))
{
i
=
j
+
1
;
n
-=
m
+
1
;
}
else
{
n
=
m
;
}
}
return
i
;
}
template
<
typename
T
>
__global__
void
split_kernel
(
int
nsegment
,
int
const
*
__restrict__
segment_offsets
,
T
const
*
__restrict__
idata
,
T
*
const
*
odatas
,
int
nx
,
int
srcny_
,
int
nz
)
{
int
x0
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
src_y0
=
threadIdx
.
y
+
blockIdx
.
y
*
blockDim
.
y
;
int
z0
=
threadIdx
.
z
+
blockIdx
.
z
*
blockDim
.
z
;
for
(
int
z
=
z0
;
z
<
nz
;
z
+=
blockDim
.
z
*
gridDim
.
z
)
{
for
(
int
src_y
=
src_y0
;
src_y
<
srcny_
;
src_y
+=
blockDim
.
y
*
gridDim
.
y
)
{
for
(
int
x
=
x0
;
x
<
nx
;
x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
segment
=
upper_bound
(
segment_offsets
,
nsegment
,
src_y
)
-
1
;
int
dst_y
=
src_y
-
segment_offsets
[
segment
];
int
dstny_
=
segment_offsets
[
segment
+
1
]
-
segment_offsets
[
segment
];
odatas
[
segment
][
x
+
nx
*
(
dst_y
+
dstny_
*
z
)]
=
idata
[
x
+
nx
*
(
src_y
+
srcny_
*
z
)];
}
}
}
}
int
SplitPlugin
::
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
int
SplitPlugin
::
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
const
&
input_dims
=
this
->
getInputDims
(
0
);
auto
const
&
input_dims
=
this
->
getInputDims
(
0
);
int
input_size
=
0
;
int
const
*
d_segment_offsets_ptr
=
int
const
*
d_segment_offsets_ptr
=
thrust
::
raw_pointer_cast
(
&
d_segment_offsets_
[
0
]);
thrust
::
raw_pointer_cast
(
&
d_segment_offsets_
[
0
]);
float
const
*
idata
=
reinterpret_cast
<
float
const
*>
(
inputs
[
0
]);
float
const
*
idata
=
reinterpret_cast
<
float
const
*>
(
inputs
[
0
]);
float
**
odatas
=
reinterpret_cast
<
float
**>
(
outputs
);
float
**
odatas
=
reinterpret_cast
<
float
**>
(
outputs
);
int
nz
=
nz_
*
batchSize
;
// kernel impl here.
dim3
block
(
32
,
16
);
int
inputBatchOffset
=
nx_
*
ny_
*
nz_
;
dim3
grid
(
std
::
min
((
nx_
-
1
)
/
block
.
x
+
1
,
65535u
),
for
(
size_t
i
=
0
;
i
<
this
->
getNbOutputs
();
i
++
)
{
std
::
min
((
ny_
-
1
)
/
block
.
y
+
1
,
65535u
),
for
(
size_t
j
=
0
;
j
<
batchSize
;
j
++
)
{
std
::
min
((
nz_
-
1
)
/
block
.
z
+
1
,
65535u
));
cudaMemcpyAsync
(
odatas
[
i
]
+
split_kernel
<<<
grid
,
block
,
0
,
stream
>>>
(
d_segment_offsets_
.
size
(),
j
*
(
segment_offsets_
[
i
+
1
]
-
segment_offsets_
[
i
])
*
nx_
*
d_segment_offsets_ptr
,
idata
,
odatas
,
sizeof
(
float
),
nx_
,
ny_
,
nz
);
inputs
[
0
]
+
(
inputBatchOffset
*
j
+
segment_offsets_
[
i
]
*
nx_
)
*
sizeof
(
float
),
(
segment_offsets_
[
i
+
1
]
-
segment_offsets_
[
i
])
*
nx_
*
sizeof
(
float
),
cudaMemcpyDeviceToDevice
,
stream
);
}
}
return
cudaGetLastError
()
!=
cudaSuccess
;
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
...
...
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
浏览文件 @
d38fd6a0
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#pragma once
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include <thrust/device_vector.h>
#include <thrust/device_vector.h>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -10,53 +23,55 @@ namespace tensorrt {
...
@@ -10,53 +23,55 @@ namespace tensorrt {
class
SplitPlugin
:
public
PluginTensorRT
{
class
SplitPlugin
:
public
PluginTensorRT
{
int
axis_
;
int
axis_
;
std
::
vector
<
int
>
output_leng
ht
_
;
std
::
vector
<
int
>
output_leng
th
_
;
int
nx_
,
ny_
,
nz_
;
int
nx_
,
ny_
,
nz_
;
thrust
::
device_vector
<
int
>
d_segment_offsets_
;
thrust
::
device_vector
<
int
>
d_segment_offsets_
;
std
::
vector
<
int
>
segment_offsets_
;
protected:
protected:
virtual
size_t
getSerializationSize
()
override
{
virtual
size_t
getSerializationSize
()
override
{
return
serialized_size
(
axis_
)
+
serialized_size
(
output_leng
ht_
)
return
serialized_size
(
axis_
)
+
serialized_size
(
output_leng
th_
)
+
+
getBaseSerializationSize
();
getBaseSerializationSize
();
}
}
virtual
void
serialize
(
void
*
buffer
)
override
{
virtual
void
serialize
(
void
*
buffer
)
override
{
serializeBase
(
buffer
);
serializeBase
(
buffer
);
serialize_value
(
&
buffer
,
axis_
);
serialize_value
(
&
buffer
,
axis_
);
serialize_value
(
&
buffer
,
output_leng
ht
_
);
serialize_value
(
&
buffer
,
output_leng
th
_
);
}
}
public:
public:
Split
()
{}
SplitPlugin
(
int
axis
,
std
::
vector
<
int
>
const
&
output_lengths
)
SplitPlugin
(
void
const
*
serialData
,
size_t
serialLength
)
{
:
axis_
(
axis
),
output_length_
(
output_lengths
)
{
assert
(
axis
<=
nvinfer1
::
Dims
::
MAX_DIMS
);
}
SplitPlugin
(
void
const
*
serialData
,
size_t
serialLength
)
{
deserializeBase
(
serialData
,
serialLength
);
deserializeBase
(
serialData
,
serialLength
);
deserialize_value
(
&
serialData
,
&
serialLength
,
&
axis_
);
deserialize_value
(
&
serialData
,
&
serialLength
,
&
axis_
);
deserialize_value
(
&
serialData
,
&
serialLength
,
&
output_leng
ht
_
);
deserialize_value
(
&
serialData
,
&
serialLength
,
&
output_leng
th
_
);
}
}
SplitPlugin
*
clone
()
const
override
{
SplitPlugin
*
clone
()
const
override
{
return
new
SplitPlugin
(
axis_
,
output_leng
ht
_
);
return
new
SplitPlugin
(
axis_
,
output_leng
th
_
);
}
}
virtual
const
char
*
getPluginType
()
const
override
{
return
"split"
;
}
virtual
const
char
*
getPluginType
()
const
override
{
return
"split"
;
}
virtual
int
getNbOutputs
()
const
override
{
return
output_leng
ht
_
.
size
();
}
virtual
int
getNbOutputs
()
const
override
{
return
output_leng
th
_
.
size
();
}
virtual
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
virtual
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nbInputDims
)
override
;
const
nvinfer1
::
Dims
*
inputs
,
int
nbInputDims
)
override
;
virtual
int
initialize
()
override
;
virtual
int
initialize
()
override
;
virtual
int
enqueue
(
int
batchSize
,
virtual
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
void
*
workspace
,
cudaStream_t
stream
)
override
;
void
setAxis
(
int
axis
)
{
void
setAxis
(
int
axis
)
{
axis_
=
axis
;
}
axis_
=
axis
;
}
void
setOutputLengths
(
const
std
::
vector
<
int
>
&
output_lengths
)
{
void
setOutputLengths
(
const
std
::
vector
<
int
>
&
output_lengths
)
{
output_length_
=
output_lengths
;
output_length_
=
output_lengths
;
}
}
};
};
}
// tensorrt
}
// tensorrt
}
// inference
}
// inference
}
// paddle
}
// paddle
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
浏览文件 @
d38fd6a0
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -41,8 +40,7 @@ size_t PluginTensorRT::getBaseSerializationSize() {
...
@@ -41,8 +40,7 @@ size_t PluginTensorRT::getBaseSerializationSize() {
bool
PluginTensorRT
::
supportsFormat
(
nvinfer1
::
DataType
type
,
bool
PluginTensorRT
::
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
)
const
{
nvinfer1
::
PluginFormat
format
)
const
{
return
((
type
==
nvinfer1
::
DataType
::
kFLOAT
||
return
((
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
format
==
nvinfer1
::
PluginFormat
::
kNCHW
));
(
format
==
nvinfer1
::
PluginFormat
::
kNCHW
));
}
}
...
...
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
浏览文件 @
d38fd6a0
...
@@ -14,14 +14,14 @@
...
@@ -14,14 +14,14 @@
#pragma once
#pragma once
#include <NvInfer.h>
#include <cassert>
#include <cassert>
#include <cstring>
#include <cstring>
#include <iostream>
#include <iostream>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/serialize.h
pp
"
#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -53,8 +53,8 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
...
@@ -53,8 +53,8 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
nvinfer1
::
DataType
type
,
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
,
nvinfer1
::
PluginFormat
format
,
int
maxBatchSize
)
override
;
int
maxBatchSize
)
override
;
virtual
void
serialize
(
void
*
buffer
)
override
;
virtual
void
serialize
(
void
*
buffer
)
=
0
;
virtual
size_t
getSerializationSize
()
override
;
virtual
size_t
getSerializationSize
()
=
0
;
protected:
protected:
void
deserializeBase
(
void
const
*&
serialData
,
size_t
&
serialLength
);
void
deserializeBase
(
void
const
*&
serialData
,
size_t
&
serialLength
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录