Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
TonyTonyFun
Paddle
提交
af97b310
P
Paddle
项目概览
TonyTonyFun
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
af97b310
编写于
6月 24, 2022
作者:
C
ccrrong
提交者:
GitHub
6月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add slice plugin int32 support (#43808)
* add slice plugin int32 support
上级
eec4e034
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
122 addition
and
30 deletion
+122
-30
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
+71
-30
python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py
...uid/tests/unittests/ir/inference/test_trt_slice_plugin.py
+51
-0
未找到文件。
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
浏览文件 @
af97b310
...
...
@@ -28,8 +28,8 @@ namespace tensorrt {
namespace
plugin
{
template
<
typename
T
>
__global__
void
SliceKernel
(
int
num
,
int
dims
,
const
T
*
input
,
const
int
*
offsets_info
,
T
*
output
)
{
__global__
void
SliceKernel
(
int
num
,
int
dims
,
const
T
*
input
,
const
int
*
offsets_info
,
T
*
output
)
{
const
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
extern
__shared__
int
shared_data
[];
...
...
@@ -54,8 +54,10 @@ __global__ void SliceKernel(int num, int dims, const T *input,
}
}
SlicePlugin
::
SlicePlugin
(
std
::
vector
<
int
>
starts
,
std
::
vector
<
int
>
ends
,
std
::
vector
<
int
>
axes
,
bool
with_fp16
)
SlicePlugin
::
SlicePlugin
(
std
::
vector
<
int
>
starts
,
std
::
vector
<
int
>
ends
,
std
::
vector
<
int
>
axes
,
bool
with_fp16
)
:
starts_
(
starts
),
ends_
(
ends
),
axes_
(
axes
)
{
with_fp16_
=
with_fp16
;
}
...
...
@@ -79,10 +81,12 @@ bool SlicePlugin::supportsFormat(
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
)
const
TRT_NOEXCEPT
{
if
(
with_fp16_
)
{
return
((
type
==
nvinfer1
::
DataType
::
kFLOAT
||
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
type
==
nvinfer1
::
DataType
::
kHALF
||
type
==
nvinfer1
::
DataType
::
kINT32
)
&&
(
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
));
}
else
{
return
((
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
return
((
type
==
nvinfer1
::
DataType
::
kFLOAT
||
type
==
nvinfer1
::
DataType
::
kINT32
)
&&
(
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
));
}
}
...
...
@@ -99,11 +103,15 @@ nvinfer1::Dims SlicePlugin::getOutputDimensions(
return
out_dims
;
}
int
SlicePlugin
::
enqueue
(
int
batch_size
,
const
void
*
const
*
inputs
,
int
SlicePlugin
::
enqueue
(
int
batch_size
,
const
void
*
const
*
inputs
,
#if IS_TRT_VERSION_LT(8000)
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
#else
void
*
const
*
outputs
,
void
*
workspace
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
#endif
auto
input_dims
=
getInputDims
(
0
);
...
...
@@ -153,8 +161,11 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
cudaMalloc
(
&
offset_temp_data_
,
3
*
num_dims
*
sizeof
(
int
));
}
cudaMemcpyAsync
(
offset_temp_data_
,
offset_info
.
data
(),
sizeof
(
int
)
*
3
*
num_dims
,
cudaMemcpyHostToDevice
,
stream
);
cudaMemcpyAsync
(
offset_temp_data_
,
offset_info
.
data
(),
sizeof
(
int
)
*
3
*
num_dims
,
cudaMemcpyHostToDevice
,
stream
);
int
threads
=
256
;
int
blocks
=
(
out_num
+
threads
-
1
)
/
threads
;
...
...
@@ -171,9 +182,15 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
SliceKernel
<
half
><<<
blocks
,
threads
,
3
*
num_dims
*
sizeof
(
int
),
stream
>>>
(
out_num
,
num_dims
,
input1
,
offset_temp_data_
,
output
);
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kINT32
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. Slice-->int32"
;
const
int
*
input1
=
static_cast
<
const
int
*>
(
inputs
[
0
]);
int
*
output
=
static_cast
<
int
*>
(
outputs
[
0
]);
SliceKernel
<
int
><<<
blocks
,
threads
,
3
*
num_dims
*
sizeof
(
int
),
stream
>>>
(
out_num
,
num_dims
,
input1
,
offset_temp_data_
,
output
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Slice TRT Plugin's input type should be float
or half
."
));
"The Slice TRT Plugin's input type should be float
, half or int
."
));
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
...
...
@@ -197,7 +214,8 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
#if IS_TRT_VERSION_GE(6000)
SlicePluginDynamic
::
SlicePluginDynamic
(
std
::
vector
<
int
>
starts
,
std
::
vector
<
int
>
ends
,
std
::
vector
<
int
>
axes
,
int
decrease_axis
,
std
::
vector
<
int
>
axes
,
int
decrease_axis
,
bool
with_fp16
)
:
starts_
(
starts
),
ends_
(
ends
),
axes_
(
axes
),
decrease_axis_
(
decrease_axis
)
{
with_fp16_
=
with_fp16
;
...
...
@@ -238,7 +256,9 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
}
nvinfer1
::
DimsExprs
SlicePluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
{
auto
in_dims
=
inputs
[
0
];
nvinfer1
::
DimsExprs
ret
=
in_dims
;
...
...
@@ -264,7 +284,8 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
for
(
size_t
i
=
0
;
i
<
in_dims
.
nbDims
;
i
++
)
{
if
(
decrease_axis_
==
i
)
continue
;
res
.
d
[
j
++
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kMAX
,
*
expr_builder
.
constant
(
0
),
*
ret
.
d
[
i
]);
*
expr_builder
.
constant
(
0
),
*
ret
.
d
[
i
]);
}
return
res
;
}
...
...
@@ -272,26 +293,33 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
}
bool
SlicePluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_NOT_NULL
(
in_out
,
platform
::
errors
::
InvalidArgument
(
"The input of swish plugin shoule not be nullptr."
));
in_out
,
platform
::
errors
::
InvalidArgument
(
"The input of swish plugin shoule not be nullptr."
));
PADDLE_ENFORCE_LT
(
pos
,
nb_inputs
+
nb_outputs
,
pos
,
nb_inputs
+
nb_outputs
,
platform
::
errors
::
InvalidArgument
(
"The pos(%d) should be less than the "
"num(%d) of the input and the output."
,
pos
,
nb_inputs
+
nb_outputs
));
pos
,
nb_inputs
+
nb_outputs
));
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
if
(
pos
==
0
)
{
if
(
with_fp16_
)
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
in
.
type
==
nvinfer1
::
DataType
::
kHALF
||
in
.
type
==
nvinfer1
::
DataType
::
kINT32
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
in
.
type
==
nvinfer1
::
DataType
::
kINT32
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
}
...
...
@@ -301,24 +329,28 @@ bool SlicePluginDynamic::supportsFormatCombination(
}
nvinfer1
::
DataType
SlicePluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The Slice Plugin only has one input, so the "
"index value should be 0, but get %d."
,
index
));
PADDLE_ENFORCE_EQ
((
input_types
[
0
]
==
nvinfer1
::
DataType
::
kFLOAT
||
input_types
[
0
]
==
nvinfer1
::
DataType
::
kHALF
),
input_types
[
0
]
==
nvinfer1
::
DataType
::
kHALF
||
input_types
[
0
]
==
nvinfer1
::
DataType
::
kINT32
),
true
,
platform
::
errors
::
InvalidArgument
(
"The input type should be half
or floa
t"
));
"The input type should be half
, float or in
t"
));
return
input_types
[
0
];
}
int
SlicePluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
auto
input_dims
=
input_desc
[
0
].
dims
;
...
...
@@ -362,8 +394,11 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
cudaMalloc
(
&
offset_temp_data_
,
3
*
num_dims
*
sizeof
(
int
));
}
cudaMemcpyAsync
(
offset_temp_data_
,
offset_info_
.
data
(),
sizeof
(
int
)
*
3
*
num_dims
,
cudaMemcpyHostToDevice
,
stream
);
cudaMemcpyAsync
(
offset_temp_data_
,
offset_info_
.
data
(),
sizeof
(
int
)
*
3
*
num_dims
,
cudaMemcpyHostToDevice
,
stream
);
int
threads
=
256
;
int
blocks
=
(
out_num
+
threads
-
1
)
/
threads
;
...
...
@@ -380,9 +415,15 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
SliceKernel
<
half
><<<
blocks
,
threads
,
3
*
num_dims
*
sizeof
(
int
),
stream
>>>
(
out_num
,
num_dims
,
input1
,
offset_temp_data_
,
output
);
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kINT32
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. Slice-->int32"
;
const
int
*
input1
=
static_cast
<
const
int
*>
(
inputs
[
0
]);
int
*
output
=
static_cast
<
int
*>
(
outputs
[
0
]);
SliceKernel
<
int
><<<
blocks
,
threads
,
3
*
num_dims
*
sizeof
(
int
),
stream
>>>
(
out_num
,
num_dims
,
input1
,
offset_temp_data_
,
output
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Slice TRT Plugin's input type should be float
or half
."
));
"The Slice TRT Plugin's input type should be float
, half or int
."
));
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_slice_plugin.py
浏览文件 @
af97b310
...
...
@@ -108,5 +108,56 @@ class StaticSlicePluginTRTTestFp32(SlicePluginTRTTest):
self
.
enable_trt
=
True
class
SlicePluginTRTTestInt32
(
SlicePluginTRTTest
):
def
setUp
(
self
):
self
.
setUpSliceParams
()
self
.
setUpTensorRTParams
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
3
,
3
,
3
,
3
],
dtype
=
"int32"
)
axes
=
self
.
params_axes
starts
=
self
.
params_starts
ends
=
self
.
params_ends
slice_out
=
fluid
.
layers
.
slice
(
data
,
axes
=
axes
,
starts
=
starts
,
ends
=
ends
)
cast_out
=
fluid
.
layers
.
cast
(
slice_out
,
'float32'
)
out
=
fluid
.
layers
.
batch_norm
(
cast_out
,
is_test
=
True
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
3
,
3
,
3
,
3
)).
astype
(
"int32"
),
}
self
.
fetch_list
=
[
out
]
class
StaticSlicePluginTRTTestInt32
(
SlicePluginTRTTest
):
def
setUpTensorRTParams
(
self
):
self
.
trt_parameters
=
SlicePluginTRTTest
.
TensorRTParam
(
1
<<
30
,
32
,
1
,
AnalysisConfig
.
Precision
.
Float32
,
True
,
False
)
self
.
enable_trt
=
True
def
setUp
(
self
):
self
.
setUpSliceParams
()
self
.
setUpTensorRTParams
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
3
,
3
,
3
,
3
],
dtype
=
"int32"
)
axes
=
self
.
params_axes
starts
=
self
.
params_starts
ends
=
self
.
params_ends
slice_out
=
fluid
.
layers
.
slice
(
data
,
axes
=
axes
,
starts
=
starts
,
ends
=
ends
)
cast_out
=
fluid
.
layers
.
cast
(
slice_out
,
'float32'
)
out
=
fluid
.
layers
.
batch_norm
(
cast_out
,
is_test
=
True
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
3
,
3
,
3
,
3
)).
astype
(
"int32"
),
}
self
.
fetch_list
=
[
out
]
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录