Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2d7134bc
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2d7134bc
编写于
11月 13, 2018
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add initial code for plugin
上级
0b388226
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
653 addition
and
2 deletion
+653
-2
paddle/fluid/inference/tensorrt/CMakeLists.txt
paddle/fluid/inference/tensorrt/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/convert/concat_op.cc
paddle/fluid/inference/tensorrt/convert/concat_op.cc
+1
-1
paddle/fluid/inference/tensorrt/plugin/.trt_plugin_utils.h.swp
...e/fluid/inference/tensorrt/plugin/.trt_plugin_utils.h.swp
+0
-0
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+2
-0
paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc
paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc
+64
-0
paddle/fluid/inference/tensorrt/plugin/plugin_factory.h
paddle/fluid/inference/tensorrt/plugin/plugin_factory.h
+91
-0
paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc
paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc
+37
-0
paddle/fluid/inference/tensorrt/plugin/plugin_utils.h
paddle/fluid/inference/tensorrt/plugin/plugin_utils.h
+34
-0
paddle/fluid/inference/tensorrt/plugin/serialize.hpp
paddle/fluid/inference/tensorrt/plugin/serialize.hpp
+111
-0
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
+114
-0
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
+62
-0
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
+63
-0
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
+72
-0
未找到文件。
paddle/fluid/inference/tensorrt/CMakeLists.txt
浏览文件 @
2d7134bc
nv_library
(
tensorrt_engine SRCS engine.cc DEPS framework_proto device_context
)
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine
)
add_subdirectory
(
plugin
)
add_subdirectory
(
convert
)
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
2d7134bc
...
...
@@ -2,7 +2,7 @@
nv_library
(
tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry
)
DEPS tensorrt_engine
tensorrt_plugin
operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine tensorrt_converter
)
...
...
paddle/fluid/inference/tensorrt/convert/concat_op.cc
浏览文件 @
2d7134bc
...
...
@@ -19,7 +19,7 @@ namespace inference {
namespace
tensorrt
{
/*
*
MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
*
ConcatOp
*/
class
ConcatOpConverter
:
public
OpConverter
{
public:
...
...
paddle/fluid/inference/tensorrt/plugin/.trt_plugin_utils.h.swp
0 → 100644
浏览文件 @
2d7134bc
文件已添加
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
0 → 100644
浏览文件 @
2d7134bc
nv_library
(
tensorrt_plugin SRCS plugin_factory.cc plugin_utils.cc
trt_plugin.cc split_op_plugin.cu DEPS enforce
)
paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc
0 → 100644
浏览文件 @
2d7134bc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/plugin_factory.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
PluginTensorRT
*
PluginFactoryTensorRT
::
createPlugin
(
const
char
*
layer_name
,
const
void
*
serial_data
,
size_t
serial_length
)
{
size_t
parsed_byte
=
0
;
std
::
string
encoded_op_name
=
ExtractOpName
(
serial_data
,
serial_length
,
&
parsed_byte
);
if
(
!
IsPlugin
(
encoded_op_name
))
{
return
nullptr
;
}
auto
plugin_ptr
=
plugin_registry_
[
encoded_op_name
].
first
(
serial_data
,
serial_length
);
owned_plugins_
.
emplace_back
(
plugin_ptr
);
return
plugin_ptr
;
}
PluginTensorRT
*
PluginFactoryTensorRT
::
CreatePlugin
(
const
std
::
string
&
op_name
)
{
if
(
!
IsPlugin
(
op_name
))
return
nullptr
;
auto
plugin_ptr
=
plugin_registry_
[
op_name
].
second
();
owned_plugins_
.
emplace_back
(
plugin_ptr
);
return
plugin_ptr
;
}
bool
PluginFactoryTensorRT
::
RegisterPlugin
(
const
std
::
string
&
op_name
,
PluginDeserializeFunc
deserialize_func
,
PluginConstructFunc
construct_func
)
{
if
(
IsPlugin
(
op_name
))
return
false
;
auto
ret
=
plugin_registry_
.
emplace
(
op_name
,
std
::
make_pair
(
deserialize_func
,
construct_func
));
return
ret
.
second
;
}
void
PluginFactoryTensorRT
::
DestroyPlugins
()
{
owned_plugins_
.
clear
();
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/plugin_factory.h
0 → 100644
浏览文件 @
2d7134bc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <unordered_map>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
PluginFactoryTensorRT
:
public
nvinfer1
::
IPluginFactory
{
public:
static
PluginFactoryTensorRT
*
GetInstance
()
{
static
PluginFactoryTensorRT
*
factory_instance
=
new
PluginFactoryTensorRT
();
return
factory_instance
;
}
// Deserialization method
PluginTensorRT
*
createPlugin
(
const
char
*
layer_name
,
const
void
*
serial_data
,
size_t
serial_length
)
override
;
// Plugin construction, PluginFactoryTensorRT owns the plugin.
PluginTensorRT
*
CreatePlugin
(
const
std
::
string
&
op_name
);
bool
RegisterPlugin
(
const
std
::
string
&
op_name
,
PluginDeserializeFunc
deserialize_func
,
PluginConstructFunc
construct_func
);
bool
IsPlugin
(
const
std
::
string
&
op_name
)
{
return
plugin_registry_
.
find
(
op_name
)
!=
plugin_registry_
.
end
();
}
size_t
CountOwnedPlugins
()
{
return
owned_plugins_
.
size
();
}
void
DestroyPlugins
();
protected:
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
PluginDeserializeFunc
,
PluginConstructFunc
>>
plugin_registry_
;
std
::
vector
<
std
::
unique_ptr
<
PluginTensorRT
>>
owned_plugins_
;
};
class
TrtPluginRegistrar
{
public:
TrtPluginRegistrar
(
const
std
::
string
&
name
,
PluginDeserializeFunc
deserialize_func
,
PluginConstructFunc
construct_func
)
{
auto
factory
=
PluginFactoryTensorRT
::
GetInstance
();
// platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func,
// construct_func), "Falied to register plugin [%s]", name);
// platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func,
// construct_func));
factory
->
RegisterPlugin
(
name
,
deserialize_func
,
construct_func
);
}
};
#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \
REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \
construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \
construct_func) \
REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \
static ::paddle::inference::tensorrt::TrtPluginRegistrar \
trt_plugin_registrar##ctr __attribute__((unused)) = \
::paddle::inference::tensorrt::TrtPluginRegistrar( \
name, deserialize_func, construct_func)
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc
0 → 100644
浏览文件 @
2d7134bc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
#include <cassert>
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
std
::
string
ExtractOpName
(
const
void
*
serial_data
,
size_t
serial_length
,
size_t
*
incremental
)
{
size_t
op_name_char_count
=
*
static_cast
<
const
size_t
*>
(
serial_data
);
*
incremental
=
sizeof
(
size_t
)
+
op_name_char_count
;
assert
(
serial_length
>=
*
incremental
);
const
char
*
buffer
=
static_cast
<
const
char
*>
(
serial_data
)
+
sizeof
(
size_t
);
std
::
string
op_name
(
buffer
,
op_name_char_count
);
return
op_name
;
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/plugin_utils.h
0 → 100644
浏览文件 @
2d7134bc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
typedef
std
::
function
<
PluginTensorRT
*
(
const
void
*
,
size_t
)
>
PluginDeserializeFunc
;
typedef
std
::
function
<
PluginTensorRT
*
(
void
)
>
PluginConstructFunc
;
std
::
string
ExtractOpName
(
const
void
*
serial_data
,
size_t
serial_length
,
size_t
*
incremental
);
}
// namespace tensorrt
}
// namespace inference
}
// namespze paddle
paddle/fluid/inference/tensorrt/plugin/serialize.hpp
0 → 100644
浏览文件 @
2d7134bc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstring>
#include <type_traits>
#include <vector>
template
<
typename
T
>
inline
void
serialize_value
(
void
**
buffer
,
T
const
&
value
);
template
<
typename
T
>
inline
void
deserialize_value
(
void
const
**
buffer
,
size_t
*
buffer_size
,
T
*
value
);
namespace
{
template
<
typename
T
,
class
Enable
=
void
>
struct
Serializer
{};
template
<
typename
T
>
struct
Serializer
<
T
,
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
||
std
::
is_enum
<
T
>::
value
||
std
::
is_pod
<
T
>::
value
>::
type
>
{
static
size_t
serialized_size
(
T
const
&
value
)
{
return
sizeof
(
T
);
}
static
void
serialize
(
void
**
buffer
,
T
const
&
value
)
{
::
memcpy
(
*
buffer
,
&
value
,
sizeof
(
T
));
reinterpret_cast
<
char
*&>
(
*
buffer
)
+=
sizeof
(
T
);
}
static
void
deserialize
(
void
const
**
buffer
,
size_t
*
buffer_size
,
T
*
value
)
{
assert
(
*
buffer_size
>=
sizeof
(
T
));
::
memcpy
(
value
,
*
buffer
,
sizeof
(
T
));
reinterpret_cast
<
char
const
*&>
(
*
buffer
)
+=
sizeof
(
T
);
*
buffer_size
-=
sizeof
(
T
);
}
};
template
<
>
struct
Serializer
<
const
char
*>
{
static
size_t
serialized_size
(
const
char
*
value
)
{
return
strlen
(
value
)
+
1
;
}
static
void
serialize
(
void
**
buffer
,
const
char
*
value
)
{
::
strcpy
(
static_cast
<
char
*>
(
*
buffer
),
value
);
reinterpret_cast
<
char
*&>
(
*
buffer
)
+=
strlen
(
value
)
+
1
;
}
static
void
deserialize
(
void
const
**
buffer
,
size_t
*
buffer_size
,
const
char
**
value
)
{
*
value
=
static_cast
<
char
const
*>
(
*
buffer
);
size_t
data_size
=
strnlen
(
*
value
,
*
buffer_size
)
+
1
;
assert
(
*
buffer_size
>=
data_size
);
reinterpret_cast
<
char
const
*&>
(
*
buffer
)
+=
data_size
;
*
buffer_size
-=
data_size
;
}
};
template
<
typename
T
>
struct
Serializer
<
std
::
vector
<
T
>
,
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
||
std
::
is_enum
<
T
>::
value
||
std
::
is_pod
<
T
>::
value
>::
type
>
{
static
size_t
serialized_size
(
std
::
vector
<
T
>
const
&
value
)
{
return
sizeof
(
value
.
size
())
+
value
.
size
()
*
sizeof
(
T
);
}
static
void
serialize
(
void
**
buffer
,
std
::
vector
<
T
>
const
&
value
)
{
serialize_value
(
buffer
,
value
.
size
());
size_t
nbyte
=
value
.
size
()
*
sizeof
(
T
);
::
memcpy
(
*
buffer
,
value
.
data
(),
nbyte
);
reinterpret_cast
<
char
*&>
(
*
buffer
)
+=
nbyte
;
}
static
void
deserialize
(
void
const
**
buffer
,
size_t
*
buffer_size
,
std
::
vector
<
T
>*
value
)
{
size_t
size
;
deserialize_value
(
buffer
,
buffer_size
,
&
size
);
value
->
resize
(
size
);
size_t
nbyte
=
value
->
size
()
*
sizeof
(
T
);
assert
(
*
buffer_size
>=
nbyte
);
::
memcpy
(
value
->
data
(),
*
buffer
,
nbyte
);
reinterpret_cast
<
char
const
*&>
(
*
buffer
)
+=
nbyte
;
*
buffer_size
-=
nbyte
;
}
};
}
// namespace
template
<
typename
T
>
inline
size_t
serialized_size
(
T
const
&
value
)
{
return
Serializer
<
T
>::
serialized_size
(
value
);
}
template
<
typename
T
>
inline
void
serialize_value
(
void
**
buffer
,
T
const
&
value
)
{
return
Serializer
<
T
>::
serialize
(
buffer
,
value
);
}
template
<
typename
T
>
inline
void
deserialize_value
(
void
const
**
buffer
,
size_t
*
buffer_size
,
T
*
value
)
{
return
Serializer
<
T
>::
deserialize
(
buffer
,
buffer_size
,
value
);
}
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
0 → 100644
浏览文件 @
2d7134bc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
SplitPlugin
*
CreateSplitPlugin
()
{
return
new
SplitPlugin
();
};
nvinfer1
::
Dims
SplitPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
)
{
assert
(
nbInputs
==
1
);
assert
(
index
<
this
->
getNbOutputs
());
nvinfer1
::
Dims
const
&
input_dims
=
inputDims
[
0
];
nvinfer1
::
Dims
output_dims
=
input_dims
;
output_dims
.
d
[
axis_
]
=
output_lenght_
.
at
(
index
);
return
output_dims
;
}
int
SplitPlugin
::
initialize
()
{
std
::
vector
<
int
>
segment_offsets
(
1
,
0
);
for
(
int
i
=
0
;
i
<
this
->
getNbOutputs
();
++
i
)
{
segment_offsets
.
push_back
(
segment_offsets
.
back
()
+
output_lenght_
[
i
]);
}
d_segment_offsets_
=
segment_offsets
;
nvinfer1
::
Dims
dims
=
this
->
getInputDims
(
0
);
nx_
=
1
;
for
(
int
i
=
dims
.
nbDims
-
1
;
i
>
axis_
;
--
i
)
{
nx_
*=
dims
.
d
[
i
];
}
ny_
=
dims
.
d
[
axis_
];
nz_
=
1
;
for
(
int
i
=
axis_
-
1
;
i
>=
0
;
--
i
)
{
nz_
*=
dims
.
d
[
i
];
}
return
0
;
}
template
<
typename
T
>
__device__
int
upper_bound
(
T
const
*
vals
,
int
n
,
T
const
&
key
)
{
int
i
=
0
;
while
(
n
>
0
)
{
int
m
=
n
/
2
;
int
j
=
i
+
m
;
if
(
!
(
key
<
vals
[
j
]))
{
i
=
j
+
1
;
n
-=
m
+
1
;
}
else
{
n
=
m
;
}
}
return
i
;
}
template
<
typename
T
>
__global__
void
split_kernel
(
int
nsegment
,
int
const
*
__restrict__
segment_offsets
,
T
const
*
__restrict__
idata
,
T
*
const
*
odatas
,
int
nx
,
int
srcny_
,
int
nz
)
{
int
x0
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
src_y0
=
threadIdx
.
y
+
blockIdx
.
y
*
blockDim
.
y
;
int
z0
=
threadIdx
.
z
+
blockIdx
.
z
*
blockDim
.
z
;
for
(
int
z
=
z0
;
z
<
nz
;
z
+=
blockDim
.
z
*
gridDim
.
z
)
{
for
(
int
src_y
=
src_y0
;
src_y
<
srcny_
;
src_y
+=
blockDim
.
y
*
gridDim
.
y
)
{
for
(
int
x
=
x0
;
x
<
nx
;
x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
segment
=
upper_bound
(
segment_offsets
,
nsegment
,
src_y
)
-
1
;
int
dst_y
=
src_y
-
segment_offsets
[
segment
];
int
dstny_
=
segment_offsets
[
segment
+
1
]
-
segment_offsets
[
segment
];
odatas
[
segment
][
x
+
nx
*
(
dst_y
+
dstny_
*
z
)]
=
idata
[
x
+
nx
*
(
src_y
+
srcny_
*
z
)];
}
}
}
}
int
SplitPlugin
::
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
const
&
input_dims
=
this
->
getInputDims
(
0
);
int
const
*
d_segment_offsets_ptr
=
thrust
::
raw_pointer_cast
(
&
d_segment_offsets_
[
0
]);
float
const
*
idata
=
reinterpret_cast
<
float
const
*>
(
inputs
[
0
]);
float
**
odatas
=
reinterpret_cast
<
float
**>
(
outputs
);
int
nz
=
nz_
*
batchSize
;
dim3
block
(
32
,
16
);
dim3
grid
(
std
::
min
((
nx_
-
1
)
/
block
.
x
+
1
,
65535u
),
std
::
min
((
ny_
-
1
)
/
block
.
y
+
1
,
65535u
),
std
::
min
((
nz_
-
1
)
/
block
.
z
+
1
,
65535u
));
split_kernel
<<<
grid
,
block
,
0
,
stream
>>>
(
d_segment_offsets_
.
size
(),
d_segment_offsets_ptr
,
idata
,
odatas
,
nx_
,
ny_
,
nz
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
// tensorrt
}
// inference
}
// paddle
paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
0 → 100644
浏览文件 @
2d7134bc
#pragma once
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include <thrust/device_vector.h>
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
SplitPlugin
:
public
PluginTensorRT
{
int
axis_
;
std
::
vector
<
int
>
output_lenght_
;
int
nx_
,
ny_
,
nz_
;
thrust
::
device_vector
<
int
>
d_segment_offsets_
;
protected:
virtual
size_t
getSerializationSize
()
override
{
return
serialized_size
(
axis_
)
+
serialized_size
(
output_lenght_
)
+
getBaseSerializationSize
();
}
virtual
void
serialize
(
void
*
buffer
)
override
{
serializeBase
(
buffer
);
serialize_value
(
&
buffer
,
axis_
);
serialize_value
(
&
buffer
,
output_lenght_
);
}
public:
Split
()
{}
SplitPlugin
(
void
const
*
serialData
,
size_t
serialLength
)
{
deserializeBase
(
serialData
,
serialLength
);
deserialize_value
(
&
serialData
,
&
serialLength
,
&
axis_
);
deserialize_value
(
&
serialData
,
&
serialLength
,
&
output_lenght_
);
}
SplitPlugin
*
clone
()
const
override
{
return
new
SplitPlugin
(
axis_
,
output_lenght_
);
}
virtual
const
char
*
getPluginType
()
const
override
{
return
"split"
;
}
virtual
int
getNbOutputs
()
const
override
{
return
output_lenght_
.
size
();
}
virtual
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nbInputDims
)
override
;
virtual
int
initialize
()
override
;
virtual
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
void
setAxis
(
int
axis
)
{
axis_
=
axis
;
}
void
setOutputLengths
(
const
std
::
vector
<
int
>
&
output_lengths
)
{
output_length_
=
output_lengths
;
}
};
}
// tensorrt
}
// inference
}
// paddle
paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
0 → 100644
浏览文件 @
2d7134bc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
void
PluginTensorRT
::
serializeBase
(
void
*&
buffer
)
{
serialize_value
(
&
buffer
,
input_dims_
);
serialize_value
(
&
buffer
,
max_batch_size_
);
serialize_value
(
&
buffer
,
data_type_
);
serialize_value
(
&
buffer
,
data_format_
);
}
void
PluginTensorRT
::
deserializeBase
(
void
const
*&
serialData
,
size_t
&
serialLength
)
{
deserialize_value
(
&
serialData
,
&
serialLength
,
&
input_dims_
);
deserialize_value
(
&
serialData
,
&
serialLength
,
&
max_batch_size_
);
deserialize_value
(
&
serialData
,
&
serialLength
,
&
data_type_
);
deserialize_value
(
&
serialData
,
&
serialLength
,
&
data_format_
);
}
size_t
PluginTensorRT
::
getBaseSerializationSize
()
{
return
(
serialized_size
(
input_dims_
)
+
serialized_size
(
max_batch_size_
)
+
serialized_size
(
data_type_
)
+
serialized_size
(
data_format_
));
}
bool
PluginTensorRT
::
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
)
const
{
return
((
type
==
nvinfer1
::
DataType
::
kFLOAT
||
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
format
==
nvinfer1
::
PluginFormat
::
kNCHW
));
}
void
PluginTensorRT
::
configureWithFormat
(
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
,
const
nvinfer1
::
Dims
*
outputDims
,
int
nbOutputs
,
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
,
int
maxBatchSize
)
{
data_type_
=
type
;
data_format_
=
format
;
input_dims_
.
assign
(
inputDims
,
inputDims
+
nbInputs
);
max_batch_size_
=
maxBatchSize
;
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
0 → 100644
浏览文件 @
2d7134bc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInfer.h>
#include <cassert>
#include <cstring>
#include <iostream>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/serialize.hpp"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
PluginTensorRT
:
public
nvinfer1
::
IPluginExt
{
public:
PluginTensorRT
()
{}
PluginTensorRT
(
const
void
*
serialized_data
,
size_t
length
)
{}
nvinfer1
::
Dims
const
&
getInputDims
(
int
index
)
const
{
return
input_dims_
.
at
(
index
);
}
size_t
getMaxBatchSize
()
const
{
return
max_batch_size_
;
}
nvinfer1
::
DataType
getDataType
()
const
{
return
data_type_
;
}
nvinfer1
::
PluginFormat
getDataFormat
()
const
{
return
data_format_
;
}
virtual
const
char
*
getPluginVersion
()
const
{
return
"1"
;
}
size_t
getWorkspaceSize
(
int
)
const
override
{
return
0
;
}
void
terminate
()
override
{}
virtual
~
PluginTensorRT
()
{}
// The following functions need to be overrided in the subclass.
virtual
nvinfer1
::
IPluginExt
*
clone
()
const
=
0
;
virtual
const
char
*
getPluginType
()
const
=
0
;
int
initialize
()
override
{
return
0
;
}
bool
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
)
const
override
;
void
configureWithFormat
(
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
,
const
nvinfer1
::
Dims
*
outputDims
,
int
nbOutputs
,
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
,
int
maxBatchSize
)
override
;
virtual
void
serialize
(
void
*
buffer
)
override
;
virtual
size_t
getSerializationSize
()
override
;
protected:
void
deserializeBase
(
void
const
*&
serialData
,
size_t
&
serialLength
);
size_t
getBaseSerializationSize
();
void
serializeBase
(
void
*&
buffer
);
std
::
vector
<
nvinfer1
::
Dims
>
input_dims_
;
size_t
max_batch_size_
;
nvinfer1
::
DataType
data_type_
;
nvinfer1
::
PluginFormat
data_format_
;
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录