Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
cdd7b956
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cdd7b956
编写于
11月 09, 2022
作者:
W
Wangzheee
提交者:
GitHub
11月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference]upgrade scale and slice op convert for Paddle-TensorRT (#47746)
* upgrade scale and slice op convert for Paddle-TensorRT
上级
1aa64d13
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
294 addition
and
822 deletion
+294
-822
paddle/fluid/inference/tensorrt/convert/scale_op.cc
paddle/fluid/inference/tensorrt/convert/scale_op.cc
+163
-91
paddle/fluid/inference/tensorrt/convert/slice_op.cc
paddle/fluid/inference/tensorrt/convert/slice_op.cc
+90
-82
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+36
-24
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+0
-1
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
+0
-435
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
+0
-187
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_clip.py
...uid/tests/unittests/ir/inference/test_trt_convert_clip.py
+5
-2
未找到文件。
paddle/fluid/inference/tensorrt/convert/scale_op.cc
浏览文件 @
cdd7b956
...
...
@@ -49,6 +49,77 @@ class ScaleOpConverter : public OpConverter {
PADDLE_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"bias_after_scale"
));
float
bias
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"bias"
));
float
scale
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"scale"
));
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
nvinfer1
::
ITensor
*
bias_tensor
=
Add1DConstantLayer
(
bias
);
bool
is_bias_0
=
(
bias
<
1e-06
&&
bias
>
-
1e-06
);
std
::
vector
<
int32_t
>
bias_shapes
(
input
->
getDimensions
().
nbDims
,
1
);
auto
*
bias_shapes_tensor
=
Add1DConstantLayer
(
bias_shapes
);
auto
*
reshape_layer_bias
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
bias_tensor
);
reshape_layer_bias
->
setInput
(
1
,
*
bias_shapes_tensor
);
bool
has_scale_tensor
;
nvinfer1
::
ITensor
*
scale_tensor
;
bool
is_scale_1
;
auto
scale_inputs
=
op_desc
.
Inputs
();
if
(
scale_inputs
.
find
(
"ScaleTensor"
)
!=
scale_inputs
.
end
()
&&
op_desc
.
Input
(
"ScaleTensor"
).
size
())
{
// has EndsTensor input
has_scale_tensor
=
true
;
scale_tensor
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"ScaleTensor"
)[
0
]);
is_scale_1
=
false
;
}
else
{
has_scale_tensor
=
false
;
scale_tensor
=
Add1DConstantLayer
(
scale
);
is_scale_1
=
((
scale
-
1.0
)
<
1e-06
&&
(
scale
-
1.0
)
>
-
1e-06
);
}
std
::
vector
<
int32_t
>
scale_shapes
(
input
->
getDimensions
().
nbDims
,
1
);
auto
*
scale_shapes_tensor
=
Add1DConstantLayer
(
scale_shapes
);
auto
*
reshape_layer_scale
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
scale_tensor
);
reshape_layer_scale
->
setInput
(
1
,
*
scale_shapes_tensor
);
if
(
!
has_scale_tensor
&&
is_scale_1
&&
is_bias_0
)
{
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Identity
,
*
input
);
}
else
{
if
(
bias_after_scale
)
{
if
(
!
is_scale_1
)
{
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
input
,
*
reshape_layer_scale
->
getOutput
(
0
),
nvinfer1
::
ElementWiseOperation
::
kPROD
);
input
=
layer
->
getOutput
(
0
);
}
if
(
!
is_bias_0
)
{
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
input
,
*
reshape_layer_bias
->
getOutput
(
0
),
nvinfer1
::
ElementWiseOperation
::
kSUM
);
}
}
else
{
if
(
!
is_bias_0
)
{
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
input
,
*
reshape_layer_bias
->
getOutput
(
0
),
nvinfer1
::
ElementWiseOperation
::
kSUM
);
input
=
layer
->
getOutput
(
0
);
}
if
(
!
is_scale_1
)
{
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
ElementWise
,
*
input
,
*
reshape_layer_scale
->
getOutput
(
0
),
nvinfer1
::
ElementWiseOperation
::
kPROD
);
}
}
}
}
else
{
auto
create_weights
=
[
&
](
float
data
,
std
::
string
type
)
->
float
*
{
std
::
unique_ptr
<
phi
::
DenseTensor
>
tmp_tensor
(
new
phi
::
DenseTensor
());
tmp_tensor
->
Resize
({
1
});
...
...
@@ -59,8 +130,6 @@ class ScaleOpConverter : public OpConverter {
return
tmp_data
;
};
int
dynamic_shape_offset
=
engine_
->
with_dynamic_shape
()
?
1
:
0
;
float
*
bias_ptr
=
create_weights
(
bias
,
"bias"
);
float
*
scale_ptr
=
create_weights
(
scale
,
"scale"
);
...
...
@@ -70,17 +139,16 @@ class ScaleOpConverter : public OpConverter {
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_ptr
),
1
};
TensorRTEngine
::
Weight
power_weights
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
nvinfer1
::
ILayer
*
layer
=
nullptr
;
auto
input_dim
=
input
->
getDimensions
();
nvinfer1
::
IShuffleLayer
*
expand_layer
=
nullptr
;
nvinfer1
::
IShuffleLayer
*
squeeze_layer
=
nullptr
;
if
(
input_dim
.
nbDims
<
3
+
dynamic_shape_offset
)
{
if
(
input_dim
.
nbDims
<
3
)
{
nvinfer1
::
Dims
expand_shape
;
expand_shape
.
nbDims
=
3
+
dynamic_shape_offset
;
for
(
int
i
=
0
;
i
<
3
+
dynamic_shape_offset
;
i
++
)
{
expand_shape
.
nbDims
=
3
;
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
if
(
i
<
input_dim
.
nbDims
)
{
expand_shape
.
d
[
i
]
=
input_dim
.
d
[
i
]
<
0
?
0
:
input_dim
.
d
[
i
];
}
else
{
...
...
@@ -118,7 +186,8 @@ class ScaleOpConverter : public OpConverter {
power_weights
.
get
());
layer
->
getOutput
(
0
)
->
setName
(
(
"bias_before_scale:bias_out: "
+
out_name
).
c_str
());
layer
->
setName
((
"Scale: scale_bias (Output: "
+
out_name
+
")"
).
c_str
());
layer
->
setName
(
(
"Scale: scale_bias (Output: "
+
out_name
+
")"
).
c_str
());
// mul scale
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Scale
,
...
...
@@ -129,14 +198,15 @@ class ScaleOpConverter : public OpConverter {
power_weights
.
get
());
layer
->
getOutput
(
0
)
->
setName
(
(
"bias_before_scale:scale_out: "
+
out_name
).
c_str
());
layer
->
setName
((
"Scale: scale_scale (Output: "
+
out_name
+
")"
).
c_str
());
layer
->
setName
(
(
"Scale: scale_scale (Output: "
+
out_name
+
")"
).
c_str
());
}
PADDLE_ENFORCE_EQ
(
layer
!=
nullptr
,
true
,
platform
::
errors
::
Fatal
(
"Create scale layer failed."
));
if
(
input_dim
.
nbDims
<
3
+
dynamic_shape_offset
)
{
if
(
input_dim
.
nbDims
<
3
)
{
nvinfer1
::
Dims
squeeze_shape
;
squeeze_shape
.
nbDims
=
input_dim
.
nbDims
;
for
(
int
i
=
0
;
i
<
squeeze_shape
.
nbDims
;
i
++
)
{
...
...
@@ -146,10 +216,12 @@ class ScaleOpConverter : public OpConverter {
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
(
layer
->
getOutput
(
0
)));
squeeze_layer
->
setReshapeDimensions
(
squeeze_shape
);
layer
=
static_cast
<
nvinfer1
::
ILayer
*>
(
squeeze_layer
);
layer
->
getOutput
(
0
)
->
setName
((
"after_reshape_out: "
+
out_name
).
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
(
"after_reshape_out: "
+
out_name
).
c_str
());
layer
->
setName
(
(
"Scale: Shuffle_reshape (Output: "
+
out_name
+
")"
).
c_str
());
}
}
RreplenishLayerAndOutput
(
layer
,
"scale"
,
{
out_name
},
test_mode
);
}
};
...
...
paddle/fluid/inference/tensorrt/convert/slice_op.cc
浏览文件 @
cdd7b956
...
...
@@ -10,7 +10,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -34,7 +33,6 @@ class SliceOpConverter : public OpConverter {
out_scale
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"out_threshold"
));
engine_
->
SetTensorDynamicRange
(
input
,
out_scale
);
}
std
::
vector
<
int
>
axes
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"axes"
));
std
::
vector
<
int
>
starts
=
...
...
@@ -43,82 +41,85 @@ class SliceOpConverter : public OpConverter {
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"ends"
));
std
::
vector
<
int
>
decrease_axises
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"decrease_axis"
));
auto
input_dims
=
input
->
getDimensions
();
if
(
!
engine_
->
with_dynamic_shape
())
{
// notice that input shape is [CHW] without batch axis when input has
// static shape
for
(
size_t
i
=
input_dims
.
nbDims
;
i
>
0
;
i
--
)
{
input_dims
.
d
[
i
]
=
input_dims
.
d
[
i
-
1
];
}
input_dims
.
d
[
0
]
=
1
;
// fake batchsize, not useful here
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
if
(
starts
[
i
]
<
0
)
{
starts
[
i
]
=
std
::
max
(
starts
[
i
]
+
input_dims
.
d
[
axes
[
i
]],
0
);
}
if
(
ends
[
i
]
<
0
)
{
ends
[
i
]
=
std
::
max
(
ends
[
i
]
+
input_dims
.
d
[
axes
[
i
]],
0
);
}
ends
[
i
]
=
std
::
min
(
ends
[
i
],
input_dims
.
d
[
axes
[
i
]]);
PADDLE_ENFORCE_GT
(
ends
[
i
],
starts
[
i
],
platform
::
errors
::
InvalidArgument
(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d."
,
ends
[
i
],
starts
[
i
]));
}
}
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
#if IS_TRT_VERSION_GE(6000)
auto
nchw_input_dims
=
input
->
getDimensions
();
auto
*
shape_tensor
=
Shape
(
input
);
nvinfer1
::
Dims
trt_start_dims
;
trt_start_dims
.
nbDims
=
nchw_
input_dims
.
nbDims
;
memset
(
trt_start_dims
.
d
,
0
,
sizeof
(
int32_t
)
*
nchw_
input_dims
.
nbDims
);
trt_start_dims
.
nbDims
=
input_dims
.
nbDims
;
memset
(
trt_start_dims
.
d
,
0
,
sizeof
(
int32_t
)
*
input_dims
.
nbDims
);
nvinfer1
::
Dims
trt_size_dims
=
trt_start_dims
;
nvinfer1
::
Dims
trt_end_dims
=
trt_start_dims
;
nvinfer1
::
Dims
trt_step_dims
=
trt_start_dims
;
for
(
int
i
=
0
;
i
<
trt_step_dims
.
nbDims
;
i
++
)
trt_step_dims
.
d
[
i
]
=
1
;
// input : [N,C,H,W]
bool
has_neg_indices
=
false
;
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
int
trt_axis
=
axes
[
i
];
trt_start_dims
.
d
[
trt_axis
]
=
starts
[
i
];
trt_end_dims
.
d
[
trt_axis
]
=
ends
[
i
];
if
(
starts
[
i
]
<
0
||
ends
[
i
]
<
0
)
has_neg_indices
=
true
;
nvinfer1
::
ITensor
*
start_tensor
=
nullptr
;
nvinfer1
::
ITensor
*
end_tensor
=
nullptr
;
std
::
vector
<
nvinfer1
::
ITensor
*>
starts_tensor
;
std
::
vector
<
nvinfer1
::
ITensor
*>
ends_tensor
;
for
(
int32_t
i
=
0
;
i
<
input_dims
.
nbDims
;
++
i
)
{
starts_tensor
.
push_back
(
Add1DConstantLayer
(
0
));
ends_tensor
.
push_back
(
GetEleTensorOfShape
(
shape_tensor
,
i
));
}
auto
slice_inputs
=
op_desc
.
Inputs
();
if
(
slice_inputs
.
find
(
"StartsTensor"
)
!=
slice_inputs
.
end
()
&&
op_desc
.
Input
(
"StartsTensor"
).
size
())
{
// has StartsTensor input
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
starts_tensor
[
axes
[
i
]]
=
GetEleTensorOfShape
(
engine_
->
GetITensor
(
op_desc
.
Input
(
"StartsTensor"
)[
0
]),
i
);
}
auto
*
shape_tensor
=
Shape
(
input
);
auto
*
start_tensor
=
Add1DConstantLayer
(
trt_start_dims
);
if
(
has_neg_indices
)
{
start_tensor
=
FixNegIndices
(
shape_tensor
,
start_tensor
);
}
else
{
PADDLE_ENFORCE_EQ
(
starts
.
size
(),
axes
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of this starts: %d must be "
"equal to the axes: %d."
,
starts
.
size
(),
axes
.
size
()));
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
// same as starts.size()
if
(
starts
[
i
]
<
0
)
{
starts_tensor
[
axes
[
i
]]
=
Max
(
Sum
(
Add1DConstantLayer
(
starts
[
i
]),
GetEleTensorOfShape
(
shape_tensor
,
axes
[
i
])),
Add1DConstantLayer
(
0
));
}
else
{
starts_tensor
[
axes
[
i
]]
=
Min
(
Add1DConstantLayer
(
starts
[
i
]),
GetEleTensorOfShape
(
shape_tensor
,
axes
[
i
]));
}
std
::
vector
<
nvinfer1
::
ITensor
*>
end_vec_tensor
;
for
(
int
i
=
0
;
i
<
trt_end_dims
.
nbDims
;
i
++
)
{
end_vec_tensor
.
push_back
(
GetEleTensorOfShape
(
shape_tensor
,
i
));
}
}
start_tensor
=
Concat
(
starts_tensor
);
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
int
trt_axis
=
axes
[
i
];
if
(
ends
[
i
]
>=
0
)
{
end_vec_tensor
[
trt_axis
]
=
Add1DConstantLayer
(
ends
[
i
]);
if
(
slice_inputs
.
find
(
"EndsTensor"
)
!=
slice_inputs
.
end
()
&&
op_desc
.
Input
(
"EndsTensor"
).
size
())
{
// has EndsTensor input
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
ends_tensor
[
axes
[
i
]]
=
GetEleTensorOfShape
(
engine_
->
GetITensor
(
op_desc
.
Input
(
"EndsTensor"
)[
0
]),
i
);
}
}
else
{
PADDLE_ENFORCE_EQ
(
ends
.
size
(),
axes
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of this ends: %d must be "
"equal to the axes: %d."
,
ends
.
size
(),
axes
.
size
()));
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
// same as ends.size()
if
(
ends
[
i
]
<
0
)
{
ends_tensor
[
axes
[
i
]]
=
Max
(
Sum
(
Add1DConstantLayer
(
ends
[
i
]),
GetEleTensorOfShape
(
shape_tensor
,
axes
[
i
])),
Add1DConstantLayer
(
0
));
}
else
{
end_vec_tensor
[
trt_axis
]
=
Sum
(
end_vec_tensor
[
trt_axis
],
Add1DConstantLayer
(
ends
[
i
]));
ends_tensor
[
axes
[
i
]]
=
Min
(
Add1DConstantLayer
(
ends
[
i
]),
GetEleTensorOfShape
(
shape_tensor
,
axes
[
i
]));
}
}
// CI failed in trt 6015 but success in 7134, may be a trt bug
#if IS_TRT_VERSION_GE(7134)
auto
*
size_tensor
=
Sub
(
Min
(
Concat
(
end_vec_tensor
),
shape_tensor
),
start_tensor
);
#else
auto
*
size_tensor
=
Sub
(
Concat
(
end_vec_tensor
),
start_tensor
);
#endif
}
end_tensor
=
Concat
(
ends_tensor
);
auto
*
size_tensor
=
Sub
(
end_tensor
,
start_tensor
);
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Slice
,
*
input
,
trt_start_dims
,
trt_size_dims
,
trt_step_dims
);
...
...
@@ -139,16 +140,30 @@ class SliceOpConverter : public OpConverter {
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
layer
->
getOutput
(
0
));
layer
->
setInput
(
1
,
*
real_size_tensor
);
}
#else
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
int
decrease_axis
=
decrease_axises
.
size
()
==
0
?
-
1
:
decrease_axises
[
0
];
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
axes
,
decrease_axis
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
&
input
,
1
,
plugin
);
#endif
}
else
{
#if IS_TRT_VERSION_GE(6000)
// notice that input shape is [CHW] without batch axis when input has
// static shape
for
(
size_t
i
=
input_dims
.
nbDims
;
i
>
0
;
i
--
)
{
input_dims
.
d
[
i
]
=
input_dims
.
d
[
i
-
1
];
}
input_dims
.
d
[
0
]
=
1
;
// fake batchsize, not useful here
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
if
(
starts
[
i
]
<
0
)
{
starts
[
i
]
=
std
::
max
(
starts
[
i
]
+
input_dims
.
d
[
axes
[
i
]],
0
);
}
if
(
ends
[
i
]
<
0
)
{
ends
[
i
]
=
std
::
max
(
ends
[
i
]
+
input_dims
.
d
[
axes
[
i
]],
0
);
}
ends
[
i
]
=
std
::
min
(
ends
[
i
],
input_dims
.
d
[
axes
[
i
]]);
PADDLE_ENFORCE_GT
(
ends
[
i
],
starts
[
i
],
platform
::
errors
::
InvalidArgument
(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d."
,
ends
[
i
],
starts
[
i
]));
}
auto
chw_input_dims
=
input
->
getDimensions
();
nvinfer1
::
Dims
trt_start_dims
;
trt_start_dims
.
nbDims
=
chw_input_dims
.
nbDims
;
...
...
@@ -189,13 +204,6 @@ class SliceOpConverter : public OpConverter {
reshape_layer
->
setReshapeDimensions
(
real_trt_size_dims
);
layer
=
static_cast
<
nvinfer1
::
ILayer
*>
(
reshape_layer
);
}
#else
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SlicePlugin
*
plugin
=
new
plugin
::
SlicePlugin
(
starts
,
ends
,
axes
,
with_fp16
);
layer
=
engine_
->
AddPlugin
(
&
input
,
1
,
plugin
);
#endif
}
RreplenishLayerAndOutput
(
layer
,
"slice"
,
{
output_name
},
test_mode
);
}
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
cdd7b956
...
...
@@ -1172,25 +1172,13 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
}
if
(
!
desc
.
HasAttr
(
"axes"
)
||
!
desc
.
HasAttr
(
"starts"
)
||
!
desc
.
HasAttr
(
"ends"
))
{
std
::
vector
<
int
>
axes
;
if
(
!
desc
.
HasAttr
(
"axes"
))
{
VLOG
(
3
)
<<
"The necessary attributes of the slice operator axes "
"
or starts or ends
are missing."
;
" are missing."
;
return
false
;
}
else
{
std
::
vector
<
int
>
axes
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"axes"
));
std
::
vector
<
int
>
starts
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"starts"
));
std
::
vector
<
int
>
ends
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"ends"
));
if
(
axes
.
size
()
!=
starts
.
size
()
||
axes
.
size
()
!=
ends
.
size
())
{
VLOG
(
3
)
<<
"The shape of attributes of the slice operator axes "
"or starts or ends are not equal."
;
return
false
;
}
axes
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"axes"
));
if
(
!
with_dynamic_shape
)
{
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
if
(
axes
[
i
]
==
0
)
{
...
...
@@ -1203,16 +1191,44 @@ struct SimpleOpTypeSetTeller : public Teller {
}
// not support following four inputs for slice in paddle-trt
auto
slice_inputs
=
desc
.
Inputs
();
// its size == 5
if
(
slice_inputs
.
find
(
"StartsTensor"
)
!=
slice_inputs
.
end
())
{
if
(
desc
.
Input
(
"StartsTensor"
).
size
())
{
if
(
slice_inputs
.
find
(
"StartsTensor"
)
!=
slice_inputs
.
end
()
&&
desc
.
Input
(
"StartsTensor"
).
size
())
{
VLOG
(
3
)
<<
"The Slice has StartsTensor input."
;
}
else
{
if
(
!
desc
.
HasAttr
(
"starts"
))
{
VLOG
(
3
)
<<
"The necessary attributes of the slice operator starts or "
"StartsTensor"
" are missing."
;
return
false
;
}
else
{
std
::
vector
<
int
>
starts
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"starts"
));
if
(
axes
.
size
()
!=
starts
.
size
())
{
VLOG
(
3
)
<<
"The shape of attributes of the slice operator axes "
"and starts are not equal."
;
return
false
;
}
}
}
if
(
slice_inputs
.
find
(
"EndsTensor"
)
!=
slice_inputs
.
end
())
{
if
(
desc
.
Input
(
"EndsTensor"
).
size
())
{
if
(
slice_inputs
.
find
(
"EndsTensor"
)
!=
slice_inputs
.
end
()
&&
desc
.
Input
(
"EndsTensor"
).
size
())
{
VLOG
(
3
)
<<
"The Slice has EndsTensor input."
;
}
else
{
if
(
!
desc
.
HasAttr
(
"ends"
))
{
VLOG
(
3
)
<<
"The necessary attributes of the slice operator ends or "
"EndsTensor"
" are missing."
;
return
false
;
}
else
{
std
::
vector
<
int
>
ends
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"ends"
));
if
(
axes
.
size
()
!=
ends
.
size
())
{
VLOG
(
3
)
<<
"The shape of attributes of the slice operator axes "
"and ends are not equal."
;
return
false
;
}
}
}
if
(
slice_inputs
.
find
(
"StartsTensorList"
)
!=
slice_inputs
.
end
())
{
if
(
desc
.
Input
(
"StartsTensorList"
).
size
())
{
return
false
;
...
...
@@ -1833,10 +1849,6 @@ struct SimpleOpTypeSetTeller : public Teller {
auto
x_var_name
=
desc
.
Input
(
"X"
)[
0
];
auto
*
x_var_desc
=
block
->
FindVar
(
x_var_name
);
const
auto
x_shape
=
x_var_desc
->
GetShape
();
if
(
x_shape
.
size
()
==
1
)
{
VLOG
(
3
)
<<
"clip op does not support input's dim is 1 in tensorrt."
;
return
false
;
}
}
if
(
op_type
==
"reduce_sum"
||
op_type
==
"reduce_mean"
)
{
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
cdd7b956
...
...
@@ -14,7 +14,6 @@ list(
emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu
skip_layernorm_op_plugin.cu
slice_op_plugin.cu
hard_swish_op_plugin.cu
stack_op_plugin.cu
anchor_generator_op_plugin.cu
...
...
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
已删除
100644 → 0
浏览文件 @
1aa64d13
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
template
<
typename
T
>
__global__
void
SliceKernel
(
int
num
,
int
dims
,
const
T
*
input
,
const
int
*
offsets_info
,
T
*
output
)
{
const
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
extern
__shared__
int
shared_data
[];
for
(
int
i
=
threadIdx
.
x
;
i
<
dims
*
3
;
i
+=
blockDim
.
x
)
{
shared_data
[
i
]
=
offsets_info
[
i
];
}
__syncthreads
();
if
(
idx
<
num
)
{
int
t_idx
=
idx
;
int
in_idx
=
0
;
for
(
int
i
=
dims
-
1
;
i
>=
0
;
i
--
)
{
// output_shape
auto
t
=
t_idx
%
shared_data
[
i
*
3
+
1
];
// out offset
auto
s
=
t
+
shared_data
[
i
*
3
];
// input_seg_offset
in_idx
=
in_idx
+
shared_data
[
i
*
3
+
2
]
*
s
;
t_idx
=
t_idx
/
shared_data
[
i
*
3
+
1
];
}
output
[
idx
]
=
input
[
in_idx
];
}
}
SlicePlugin
::
SlicePlugin
(
std
::
vector
<
int
>
starts
,
std
::
vector
<
int
>
ends
,
std
::
vector
<
int
>
axes
,
bool
with_fp16
)
:
starts_
(
starts
),
ends_
(
ends
),
axes_
(
axes
)
{
with_fp16_
=
with_fp16
;
}
SlicePlugin
::
SlicePlugin
(
void
const
*
serial_data
,
size_t
serial_length
)
{
deserializeBase
(
serial_data
,
serial_length
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
starts_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
ends_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
axes_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
offset_info_
);
}
SlicePlugin
::~
SlicePlugin
()
{
cudaFree
(
offset_temp_data_
);
}
SlicePlugin
*
SlicePlugin
::
clone
()
const
TRT_NOEXCEPT
{
return
new
SlicePlugin
(
starts_
,
ends_
,
axes_
,
with_fp16_
);
}
bool
SlicePlugin
::
supportsFormat
(
nvinfer1
::
DataType
type
,
nvinfer1
::
PluginFormat
format
)
const
TRT_NOEXCEPT
{
if
(
with_fp16_
)
{
return
((
type
==
nvinfer1
::
DataType
::
kFLOAT
||
type
==
nvinfer1
::
DataType
::
kHALF
||
type
==
nvinfer1
::
DataType
::
kINT32
)
&&
(
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
));
}
else
{
return
((
type
==
nvinfer1
::
DataType
::
kFLOAT
||
type
==
nvinfer1
::
DataType
::
kINT32
)
&&
(
format
==
nvinfer1
::
PluginFormat
::
kLINEAR
));
}
}
nvinfer1
::
Dims
SlicePlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nb_input_dims
)
TRT_NOEXCEPT
{
auto
in_dims
=
inputs
[
0
];
nvinfer1
::
Dims
out_dims
=
in_dims
;
for
(
size_t
i
=
0
;
i
<
axes_
.
size
();
i
++
)
{
int
start
=
starts_
[
i
];
int
end
=
ends_
[
i
];
out_dims
.
d
[
axes_
[
i
]
-
1
]
=
end
-
start
;
}
return
out_dims
;
}
int
SlicePlugin
::
enqueue
(
int
batch_size
,
const
void
*
const
*
inputs
,
#if IS_TRT_VERSION_LT(8000)
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
#else
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
#endif
auto
input_dims
=
getInputDims
(
0
);
// notice input dims is [C, H, W], add input batch dim here
auto
out_dims
=
getOutputDimensions
(
0
,
&
input_dims
,
1
);
input_dims
.
nbDims
+=
1
;
out_dims
.
nbDims
+=
1
;
for
(
auto
i
=
input_dims
.
nbDims
;
i
>
0
;
--
i
)
{
input_dims
.
d
[
i
]
=
input_dims
.
d
[
i
-
1
];
out_dims
.
d
[
i
]
=
out_dims
.
d
[
i
-
1
];
}
input_dims
.
d
[
0
]
=
batch_size
;
out_dims
.
d
[
0
]
=
batch_size
;
auto
num_dims
=
input_dims
.
nbDims
;
size_t
out_num
=
ProductDim
(
out_dims
);
std
::
vector
<
int
>
seg_offsets
;
std
::
vector
<
int
>
offsets
;
std
::
vector
<
int
>
extends
;
offsets
.
resize
(
num_dims
);
extends
.
resize
(
num_dims
);
seg_offsets
.
resize
(
num_dims
);
seg_offsets
[
num_dims
-
1
]
=
1
;
for
(
int
i
=
num_dims
-
2
;
i
>=
0
;
i
--
)
{
seg_offsets
[
i
]
=
input_dims
.
d
[
i
+
1
]
*
seg_offsets
[
i
+
1
];
}
for
(
size_t
i
=
0
;
i
<
num_dims
;
++
i
)
{
offsets
[
i
]
=
0
;
extends
[
i
]
=
out_dims
.
d
[
i
];
}
for
(
size_t
i
=
0
;
i
<
axes_
.
size
();
++
i
)
{
offsets
[
axes_
[
i
]]
=
starts_
[
i
];
}
std
::
vector
<
int
>
offset_info
;
for
(
size_t
i
=
0
;
i
<
num_dims
;
++
i
)
{
offset_info
.
push_back
(
offsets
[
i
]);
offset_info
.
push_back
(
extends
[
i
]);
offset_info
.
push_back
(
seg_offsets
[
i
]);
}
if
(
offset_temp_data_
==
nullptr
)
{
cudaMalloc
(
&
offset_temp_data_
,
3
*
num_dims
*
sizeof
(
int
));
}
cudaMemcpyAsync
(
offset_temp_data_
,
offset_info
.
data
(),
sizeof
(
int
)
*
3
*
num_dims
,
cudaMemcpyHostToDevice
,
stream
);
int
threads
=
256
;
int
blocks
=
(
out_num
+
threads
-
1
)
/
threads
;
auto
input_type
=
getDataType
();
if
(
input_type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. Slice-->fp32"
;
const
float
*
input1
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
float
*
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
SliceKernel
<
float
><<<
blocks
,
threads
,
3
*
num_dims
*
sizeof
(
int
),
stream
>>>
(
out_num
,
num_dims
,
input1
,
offset_temp_data_
,
output
);
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kHALF
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. Slice-->fp16"
;
const
half
*
input1
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
SliceKernel
<
half
><<<
blocks
,
threads
,
3
*
num_dims
*
sizeof
(
int
),
stream
>>>
(
out_num
,
num_dims
,
input1
,
offset_temp_data_
,
output
);
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kINT32
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. Slice-->int32"
;
const
int
*
input1
=
static_cast
<
const
int
*>
(
inputs
[
0
]);
int
*
output
=
static_cast
<
int
*>
(
outputs
[
0
]);
SliceKernel
<
int
><<<
blocks
,
threads
,
3
*
num_dims
*
sizeof
(
int
),
stream
>>>
(
out_num
,
num_dims
,
input1
,
offset_temp_data_
,
output
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Slice TRT Plugin's input type should be float, half or int."
));
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
size_t
SlicePlugin
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
return
getBaseSerializationSize
()
+
SerializedSize
(
starts_
)
+
SerializedSize
(
ends_
)
+
SerializedSize
(
axes_
)
+
SerializedSize
(
with_fp16_
)
+
SerializedSize
(
offset_info_
);
}
void
SlicePlugin
::
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
{
serializeBase
(
buffer
);
SerializeValue
(
&
buffer
,
starts_
);
SerializeValue
(
&
buffer
,
ends_
);
SerializeValue
(
&
buffer
,
axes_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
SerializeValue
(
&
buffer
,
offset_info_
);
}
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
SlicePluginDynamic
::
SlicePluginDynamic
(
std
::
vector
<
int
>
starts
,
std
::
vector
<
int
>
ends
,
std
::
vector
<
int
>
axes
,
int
decrease_axis
,
bool
with_fp16
)
:
starts_
(
starts
),
ends_
(
ends
),
axes_
(
axes
),
decrease_axis_
(
decrease_axis
)
{
with_fp16_
=
with_fp16
;
}
SlicePluginDynamic
::
SlicePluginDynamic
(
void
const
*
serialData
,
size_t
serialLength
)
{
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
starts_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ends_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
axes_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
decrease_axis_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
with_fp16_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
offset_info_
);
}
void
SlicePluginDynamic
::
destroy
()
TRT_NOEXCEPT
{
cudaFree
(
offset_temp_data_
);
delete
this
;
}
int
SlicePluginDynamic
::
initialize
()
TRT_NOEXCEPT
{
return
0
;
}
size_t
SlicePluginDynamic
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
size_t
size
=
SerializedSize
(
starts_
)
+
SerializedSize
(
ends_
)
+
SerializedSize
(
axes_
)
+
SerializedSize
(
decrease_axis_
)
+
SerializedSize
(
with_fp16_
)
+
SerializedSize
(
offset_info_
);
return
size
;
}
void
SlicePluginDynamic
::
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
{
SerializeValue
(
&
buffer
,
starts_
);
SerializeValue
(
&
buffer
,
ends_
);
SerializeValue
(
&
buffer
,
axes_
);
SerializeValue
(
&
buffer
,
decrease_axis_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
SerializeValue
(
&
buffer
,
offset_info_
);
}
nvinfer1
::
DimsExprs
SlicePluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
{
auto
in_dims
=
inputs
[
0
];
nvinfer1
::
DimsExprs
ret
=
in_dims
;
// start, ends should greater 0
for
(
size_t
i
=
0
;
i
<
axes_
.
size
();
i
++
)
{
int
start
=
starts_
[
i
];
int
end
=
ends_
[
i
];
#if IS_TRT_VERSION_GE(7200)
ret
.
d
[
axes_
[
i
]]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUB
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kMIN
,
*
expr_builder
.
constant
(
ends_
[
i
]),
*
in_dims
.
d
[
axes_
[
i
]]),
*
expr_builder
.
constant
(
start
));
#else
ret
.
d
[
axes_
[
i
]]
=
expr_builder
.
constant
(
end
-
start
);
#endif
}
if
(
decrease_axis_
!=
-
1
)
{
nvinfer1
::
DimsExprs
res
;
res
.
nbDims
=
ret
.
nbDims
-
1
;
int
j
=
0
;
for
(
size_t
i
=
0
;
i
<
in_dims
.
nbDims
;
i
++
)
{
if
(
decrease_axis_
==
i
)
continue
;
res
.
d
[
j
++
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kMAX
,
*
expr_builder
.
constant
(
0
),
*
ret
.
d
[
i
]);
}
return
res
;
}
return
ret
;
}
// Tells TensorRT which dtype/format combinations this plugin accepts at each
// I/O position. Position 0 is the input tensor: fp32 and int32 are always
// accepted, fp16 only when the plugin was built with with_fp16_; layout must
// be linear. Every later position (the output) must mirror the dtype and
// format of the tensor immediately before it.
bool SlicePluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc* in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          // Fixed message: it previously referred to the "swish plugin"
          // (copy-paste) and misspelled "should".
          "The input of slice plugin should not be nullptr."));
  PADDLE_ENFORCE_LT(
      pos,
      nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos,
                                        nb_inputs + nb_outputs));
  const nvinfer1::PluginTensorDesc& in = in_out[pos];
  if (pos == 0) {
    if (with_fp16_) {
      return (in.type == nvinfer1::DataType::kFLOAT ||
              in.type == nvinfer1::DataType::kHALF ||
              in.type == nvinfer1::DataType::kINT32) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    } else {
      // Without fp16 mode, half-precision input is rejected.
      return (in.type == nvinfer1::DataType::kFLOAT ||
              in.type == nvinfer1::DataType::kINT32) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
  }
  const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
  // output: same dtype and layout as the preceding tensor.
  return in.type == prev.type && in.format == prev.format;
}
// Reports the dtype of the single output: the slice never changes the
// element type, so the output dtype is exactly the input dtype. Only
// fp32 / fp16 / int32 inputs are legal.
nvinfer1::DataType SlicePluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index,
                    0,
                    platform::errors::InvalidArgument(
                        "The Slice Plugin only has one input, so the "
                        "index value should be 0, but get %d.",
                        index));
  const nvinfer1::DataType in_type = input_types[0];
  const bool type_is_valid = in_type == nvinfer1::DataType::kFLOAT ||
                             in_type == nvinfer1::DataType::kHALF ||
                             in_type == nvinfer1::DataType::kINT32;
  PADDLE_ENFORCE_EQ(type_is_valid,
                    true,
                    platform::errors::InvalidArgument(
                        "The input type should be half, float or int"));
  return in_type;
}
// Launches the slice computation on `stream`.
// Host side: for every input dimension a triple
// (start offset, output extent, input linear stride) is computed, staged in
// offset_info_, and copied asynchronously to the device buffer
// offset_temp_data_. Device side: SliceKernel<T> writes out_num elements,
// with T chosen from the input descriptor's dtype (fp32 / fp16 / int32).
// Returns non-zero if CUDA recorded an error.
int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                                const nvinfer1::PluginTensorDesc* output_desc,
                                const void* const* inputs,
                                void* const* outputs,
                                void* workspace,
                                cudaStream_t stream) TRT_NOEXCEPT {
  auto input_dims = input_desc[0].dims;
  auto out_dims = output_desc[0].dims;
  if (decrease_axis_ != -1) {
    // The declared output shape has the decreased axis removed; rebuild an
    // output shape of the same rank as the input (decreased axis = 1) so the
    // per-dimension bookkeeping below stays rank-aligned.
    out_dims = input_dims;
    out_dims.d[decrease_axis_] = 1;
  }
  auto num_dims = input_dims.nbDims;
  size_t out_num = ProductDim(out_dims);
  std::vector<int> seg_offsets;  // linear stride of each input dimension
  std::vector<int> offsets;      // slice start per dim (0 when not sliced)
  std::vector<int> extends;      // number of elements kept per dim
  offsets.resize(num_dims);
  extends.resize(num_dims);
  seg_offsets.resize(num_dims);
  // Row-major strides: innermost dimension has stride 1.
  seg_offsets[num_dims - 1] = 1;
  for (int i = num_dims - 2; i >= 0; i--) {
    seg_offsets[i] = input_dims.d[i + 1] * seg_offsets[i + 1];
  }
  for (size_t i = 0; i < num_dims; ++i) {
    offsets[i] = 0;
    extends[i] = out_dims.d[i];
  }
  for (size_t i = 0; i < axes_.size(); ++i) {
    offsets[axes_[i]] = starts_[i];
  }
  // Pack (offset, extent, stride) triples contiguously for the kernel.
  offset_info_.resize(num_dims * 3);
  for (size_t i = 0; i < num_dims; ++i) {
    offset_info_[i * 3 + 0] = offsets[i];
    offset_info_[i * 3 + 1] = extends[i];
    offset_info_[i * 3 + 2] = seg_offsets[i];
  }
  // Device metadata buffer is allocated lazily on first call and reused.
  // NOTE(review): the buffer is sized from the first call's num_dims —
  // assumes the rank never grows across enqueues; confirm.
  if (offset_temp_data_ == nullptr) {
    cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
  }
  cudaMemcpyAsync(offset_temp_data_,
                  offset_info_.data(),
                  sizeof(int) * 3 * num_dims,
                  cudaMemcpyHostToDevice,
                  stream);
  int threads = 256;
  // One thread per output element, rounded up to whole blocks.
  int blocks = (out_num + threads - 1) / threads;
  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp32";
    const float* input1 = static_cast<const float*>(inputs[0]);
    float* output = static_cast<float*>(outputs[0]);
    // Third launch argument reserves dynamic shared memory for the
    // 3 * num_dims ints of slice metadata.
    SliceKernel<float><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else if (input_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp16";
    const half* input1 = static_cast<const half*>(inputs[0]);
    half* output = static_cast<half*>(outputs[0]);
    SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else if (input_type == nvinfer1::DataType::kINT32) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->int32";
    const int* input1 = static_cast<const int*>(inputs[0]);
    int* output = static_cast<int*>(outputs[0]);
    SliceKernel<int><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "The Slice TRT Plugin's input type should be float, half or int."));
  }
  return cudaGetLastError() != cudaSuccess;
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
已删除
100644 → 0
浏览文件 @
1aa64d13
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
// Static-shape (implicit batch) TensorRT plugin implementing Paddle's slice
// op: keeps, along each axis in `axes_`, the index range [starts_, ends_).
class SlicePlugin : public PluginTensorRT {
 public:
  explicit SlicePlugin(std::vector<int> starts,
                       std::vector<int> ends,
                       std::vector<int> axes,
                       bool with_fp16);

  // It was used for tensorrt deserialization.
  // It should not be called by users.
  SlicePlugin(void const* serial_data, size_t serial_length);
  ~SlicePlugin();
  SlicePlugin* clone() const TRT_NOEXCEPT override;

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "slice_plugin";
  }
  // The slice op always produces exactly one output tensor.
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  // No setup work is needed; everything happens at enqueue time.
  int initialize() TRT_NOEXCEPT override { return 0; }
  bool supportsFormat(nvinfer1::DataType type,
                      nvinfer1::PluginFormat format) const TRT_NOEXCEPT override;
  nvinfer1::Dims getOutputDimensions(int index,
                                     const nvinfer1::Dims* inputs,
                                     int nb_input_dims) TRT_NOEXCEPT override;
  // TensorRT 8.0 changed the type of enqueue's `outputs` parameter.
#if IS_TRT_VERSION_LT(8000)
  int enqueue(int batch_size,
              const void* const* inputs,
              void** outputs,
#else
  int enqueue(int batch_size,
              const void* const* inputs,
              void* const* outputs,
#endif
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override;

  // TRT will call this func to serialize the configuration of TRT
  // It should not be called by users.
  void serialize(void* buffer) const TRT_NOEXCEPT override;

 private:
  std::vector<int> starts_;  // slice start index, one per entry of axes_
  std::vector<int> ends_;    // slice end index, one per entry of axes_
  std::vector<int> axes_;    // axes along which slicing is applied
  // Lazily-allocated device buffer for per-dimension slice metadata.
  int* offset_temp_data_{nullptr};
  // Host-side staging area mirrored into offset_temp_data_.
  std::vector<int> offset_info_;
};
// Factory that lets TensorRT recreate a SlicePlugin from a serialized
// engine; registered below under the name "slice_plugin", version "1".
class SlicePluginCreator : public TensorRTPluginCreator {
 public:
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "slice_plugin";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  // Called by TensorRT during engine deserialization; hands the raw bytes
  // to SlicePlugin's deserializing constructor.
  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    return new SlicePlugin(serial_data, serial_length);
  }
};
REGISTER_TRT_PLUGIN_V2(SlicePluginCreator);
#if IS_TRT_VERSION_GE(6000)
// Dynamic-shape (explicit batch) TensorRT plugin implementing Paddle's slice
// op: keeps, along each axis in `axes_`, the index range [starts_, ends_);
// `decrease_axis_` optionally names one axis to drop from the output shape.
class SlicePluginDynamic : public DynamicPluginTensorRT {
 public:
  explicit SlicePluginDynamic(std::vector<int> starts,
                              std::vector<int> ends,
                              std::vector<int> axes,
                              int decrease_axis,
                              bool with_fp16);

  // Copy carries over all slice attributes and the fp16 flag.
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    return new SlicePluginDynamic(
        starts_, ends_, axes_, decrease_axis_, with_fp16_);
  }

  // Used for TensorRT deserialization; not meant to be called by users.
  SlicePluginDynamic(void const* serialData, size_t serialLength);

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "slice_plugin_dynamic";
  }
  // The slice op always produces exactly one output tensor.
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  int initialize() TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void* buffer) const TRT_NOEXCEPT override;

  nvinfer1::DimsExprs getOutputDimensions(int output_index,
                                          const nvinfer1::DimsExprs* inputs,
                                          int nb_inputs,
                                          nvinfer1::IExprBuilder& expr_builder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  // No per-profile configuration is required.
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) TRT_NOEXCEPT override {}
  // Scratch memory is owned by the plugin itself (offset_temp_data_), so no
  // TensorRT-managed workspace is requested.
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs,
              void* const* outputs,
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const
      TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override;

 private:
  std::vector<int> starts_;  // slice start index, one per entry of axes_
  std::vector<int> ends_;    // slice end index, one per entry of axes_
  std::vector<int> axes_;    // axes along which slicing is applied
  int decrease_axis_;        // axis dropped from the output; -1 means none
  // Lazily-allocated device buffer for per-dimension slice metadata.
  int* offset_temp_data_{nullptr};
  // Host-side staging area mirrored into offset_temp_data_.
  std::vector<int> offset_info_;
};
// Factory that lets TensorRT recreate a SlicePluginDynamic from a serialized
// engine; registered below under the name "slice_plugin_dynamic", version "1".
class SlicePluginDynamicCreator : public TensorRTPluginCreator {
 public:
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "slice_plugin_dynamic";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  // Called by TensorRT during engine deserialization; hands the raw bytes
  // to SlicePluginDynamic's deserializing constructor.
  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serialData,
                                         size_t serialLength)
      TRT_NOEXCEPT override {
    return new SlicePluginDynamic(serialData, serialLength);
  }
};
REGISTER_TRT_PLUGIN_V2(SlicePluginDynamicCreator);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_clip.py
浏览文件 @
cdd7b956
...
...
@@ -120,7 +120,10 @@ class TrtConvertClipTest(TrtLayerAutoScanTest):
self
.
dynamic_shape
.
opt_input_shape
=
{}
def
generate_trt_nodes_num
(
attrs
,
dynamic_shape
):
if
self
.
input_num
==
3
or
self
.
dims
==
1
:
if
self
.
input_num
==
3
:
return
0
,
3
else
:
if
not
dynamic_shape
and
self
.
dims
==
1
:
return
0
,
3
else
:
return
1
,
2
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录