Commit 7ec1e9af
add trt support for slice op (#41467) (#41911)
Authored by feng_shuai on Apr 19, 2022; committed via GitHub on Apr 19, 2022.
Parent: 15d30815
5 changed files, with 62 additions and 27 deletions (+62 −27):
paddle/fluid/inference/tensorrt/convert/slice_op.cc (+6 −2)
paddle/fluid/inference/tensorrt/op_teller.cc (+17 −12)
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu (+22 −3)
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h (+5 −2)
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py (+12 −8)
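For context, Paddle's slice op reads starts, ends, and axes attributes and can additionally drop the sliced-out dimensions listed in decrease_axis; this commit wires that last attribute through the TensorRT converter, op teller, and plugin. A minimal standalone sketch of the shape rule (plain C++ with hypothetical helper names, not Paddle or TensorRT API):

    #include <algorithm>
    #include <iostream>
    #include <vector>

    // Hypothetical helper (not Paddle API): compute the output shape of a
    // slice given starts/ends/axes, then drop the axes in decrease_axis.
    std::vector<int> SliceOutShape(std::vector<int> shape,
                                   const std::vector<int>& axes,
                                   const std::vector<int>& starts,
                                   const std::vector<int>& ends,
                                   const std::vector<int>& decrease_axis) {
      for (size_t i = 0; i < axes.size(); ++i) {
        int d = shape[axes[i]];
        int s = starts[i] < 0 ? starts[i] + d : starts[i];
        int e = ends[i] < 0 ? ends[i] + d : std::min(ends[i], d);
        shape[axes[i]] = std::max(e - s, 0);
      }
      std::vector<int> out;
      for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
        if (std::find(decrease_axis.begin(), decrease_axis.end(), i) ==
            decrease_axis.end()) {
          out.push_back(shape[i]);
        }
      }
      return out;
    }

    int main() {
      // Values taken from the test below: input [6, 6, 64, 64], axes [0, 1],
      // starts [0, 1], ends [2, 2], decrease_axis [1]:
      // the slice gives [2, 1, 64, 64]; dropping axis 1 gives [2, 64, 64].
      for (int d : SliceOutShape({6, 6, 64, 64}, {0, 1}, {0, 1}, {2, 2}, {1})) {
        std::cout << d << ' ';  // prints: 2 64 64
      }
      std::cout << '\n';
    }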
paddle/fluid/inference/tensorrt/convert/slice_op.cc
@@ -44,6 +44,8 @@ class SliceOpConverter : public OpConverter {
         BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("starts"));
     std::vector<int> ends =
         BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
+    std::vector<int> decrease_axises =
+        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("decrease_axis"));
     auto input_dims = input->getDimensions();
     if (!engine_->with_dynamic_shape()) {

@@ -107,8 +109,10 @@ class SliceOpConverter : public OpConverter {
     } else {
       bool with_fp16 =
           engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
-      plugin::SlicePluginDynamic* plugin =
-          new plugin::SlicePluginDynamic(starts, ends, axes, with_fp16);
+      int decrease_axis =
+          decrease_axises.size() == 0 ? -1 : decrease_axises[0];
+      plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
+          starts, ends, axes, decrease_axis, with_fp16);
       layer = engine_->AddDynamicPlugin(&input, 1, plugin);
     }
   } else {
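Note the convention introduced here: the converter collapses the decrease_axis vector into a single int for the plugin, with -1 meaning "no axis to drop", and only the first entry is forwarded; op_teller.cc (next file) rejects more than one decreased axis in dynamic-shape mode, which keeps the two consistent. Restated as a hedged one-liner (hypothetical free function, not Paddle API):

    #include <vector>

    // -1 is the "no decrease" sentinel; otherwise only the first entry is used.
    int ToDecreaseAxis(const std::vector<int>& decrease_axises) {
      return decrease_axises.empty() ? -1 : decrease_axises[0];
    }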
paddle/fluid/inference/tensorrt/op_teller.cc
@@ -930,10 +930,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       if (desc.HasAttr("decrease_axis")) {
         std::vector<int> decrease_axis =
             BOOST_GET_CONST(std::vector<int>, desc.GetAttr("decrease_axis"));
-        if (decrease_axis.size() > 0) {
-          VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0"
-                     "is not supported in TensorRT";
-          return false;
+        if (with_dynamic_shape) {
+          if (decrease_axis.size() > 1) {
+            return false;
+          }
+        } else {
+          if (decrease_axis.size() > 0) {
+            VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0"
+                       "is not supported in TensorRT";
+            return false;
+          }
         }
       }
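The new gating rule: with dynamic shape, slice is offered to TensorRT as long as at most one axis is decreased; with static shape, any decrease_axis still falls back to native Paddle. As a standalone predicate (a sketch, not the actual OpTeller interface):

    #include <vector>

    // Sketch of the check above (hypothetical helper, not OpTeller):
    bool SliceDecreaseAxisSupported(const std::vector<int>& decrease_axis,
                                    bool with_dynamic_shape) {
      if (with_dynamic_shape) return decrease_axis.size() <= 1;
      return decrease_axis.empty();
    }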
@@ -1054,17 +1060,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         return false;
       }
       if (desc.Input("Ids").size() != desc.Input("Embs").size()) {
         VLOG(3) << "The id and emb size of fused EmbEltwiseLayerNormOp "
                    "should be same ";
         return false;
       }
     }
     if (op_type == "fused_preln_embedding_eltwise_layernorm") {
       if (!with_dynamic_shape) {
-        VLOG(3)
-            << "fused_preln_embedding_eltwise_layernorm should run on "
-               "dynamic "
-               "shape mode.";
+        VLOG(3)
+            << "fused_preln_embedding_eltwise_layernorm should run on dynamic "
+               "shape mode.";
         return false;
       }
       if (desc.Input("Ids").size() != desc.Input("Embs").size()) {
@@ -1454,7 +1458,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       const auto y_shape = y_var_desc->GetShape();
       if (y_shape.size() != 2) {
         VLOG(3)
             << " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes = "
             << y_shape.size();
         return false;
       }
@@ -1598,8 +1603,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       }
 #else
       if (dtype != framework::proto::VarType::FP32) {
-        VLOG(3)
-            << "reduce op input data type must be float32 using TensorRT "
-               "< 7.0";
+        VLOG(3)
+            << "reduce op input data type must be float32 using TensorRT < 7.0";
         return false;
       }
 #endif
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
@@ -205,8 +205,9 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
 #if IS_TRT_VERSION_GE(6000)
 SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
                                        std::vector<int> ends,
-                                       std::vector<int> axes, bool with_fp16)
-    : starts_(starts), ends_(ends), axes_(axes) {
+                                       std::vector<int> axes, int decrease_axis,
+                                       bool with_fp16)
+    : starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) {
   with_fp16_ = with_fp16;
   cudaEventCreate(&copy_event_);
   cudaStreamCreate(&copy_stream_);
@@ -217,6 +218,7 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
   DeserializeValue(&serialData, &serialLength, &starts_);
   DeserializeValue(&serialData, &serialLength, &ends_);
   DeserializeValue(&serialData, &serialLength, &axes_);
+  DeserializeValue(&serialData, &serialLength, &decrease_axis_);
   DeserializeValue(&serialData, &serialLength, &with_fp16_);
   cudaEventCreate(&copy_event_);
   cudaStreamCreate(&copy_stream_);
@@ -233,7 +235,8 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
 size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
   size_t size = SerializedSize(starts_) + SerializedSize(ends_) +
-                SerializedSize(axes_) + SerializedSize(with_fp16_);
+                SerializedSize(axes_) + SerializedSize(decrease_axis_) +
+                SerializedSize(with_fp16_);
   return size;
 }
@@ -242,6 +245,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
   SerializeValue(&buffer, starts_);
   SerializeValue(&buffer, ends_);
   SerializeValue(&buffer, axes_);
+  SerializeValue(&buffer, decrease_axis_);
   SerializeValue(&buffer, with_fp16_);
 }
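The serialization changes above are order-sensitive: decrease_axis_ is inserted between axes_ and with_fp16_ in getSerializationSize, serialize, and the deserializing constructor alike. If any one of the three were missed, every field after the insertion point would be read from the wrong offset when a serialized engine is reloaded. A minimal sketch of that invariant, using plain memcpy stand-ins for Paddle's SerializeValue/DeserializeValue helpers:

    #include <cassert>
    #include <cstring>

    // Write/read stand-ins for SerializeValue/DeserializeValue (sketch only).
    template <typename T>
    void Write(char** p, const T& v) {
      std::memcpy(*p, &v, sizeof(T));
      *p += sizeof(T);
    }
    template <typename T>
    void Read(const char** p, T* v) {
      std::memcpy(v, *p, sizeof(T));
      *p += sizeof(T);
    }

    int main() {
      char buf[sizeof(int) + sizeof(bool)];
      char* w = buf;
      int decrease_axis = 1;
      bool with_fp16 = true;
      Write(&w, decrease_axis);  // write side: decrease_axis, then with_fp16
      Write(&w, with_fp16);

      const char* r = buf;
      int da = 0;
      bool fp16 = false;
      Read(&r, &da);    // the read side must consume fields in the same order,
      Read(&r, &fp16);  // or with_fp16 would be read from decrease_axis's bytes
      assert(da == 1 && fp16);
      return 0;
    }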
@@ -265,6 +269,17 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
     ret.d[axes_[i]] = expr_builder.constant(end - start);
 #endif
   }
+  if (decrease_axis_ != -1) {
+    nvinfer1::DimsExprs res;
+    res.nbDims = ret.nbDims - 1;
+    int j = 0;
+    for (size_t i = 0; i < in_dims.nbDims; i++) {
+      if (decrease_axis_ == i) continue;
+      res.d[j++] = expr_builder.operation(nvinfer1::DimensionOperation::kMAX,
+                                          *expr_builder.constant(0), *ret.d[i]);
+    }
+    return res;
+  }
   return ret;
 }
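Worked example for the branch just added: with input [6, 6, 64, 64], axes = {0, 1}, starts = {0, 1}, ends = {2, 2} and decrease_axis_ = 1, the loop above first produces the full-rank result ret = [2, 1, 64, 64]; the new block then builds the rank-3 res = [2, 64, 64] by skipping axis 1. The kMAX(0, d) wrapping clamps each copied dimension expression to be non-negative, since under dynamic shapes ret.d[i] is a symbolic expression rather than a known constant.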
@@ -318,6 +333,10 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
     cudaStream_t stream) TRT_NOEXCEPT {
   auto input_dims = input_desc[0].dims;
   auto out_dims = output_desc[0].dims;
+  if (decrease_axis_ != -1) {
+    out_dims = input_dims;
+    out_dims.d[decrease_axis_] = 1;
+  }
   auto num_dims = input_dims.nbDims;
   size_t out_num = ProductDim(out_dims);
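At enqueue time, output_desc already carries the rank-reduced shape reported by getOutputDimensions, so the new guard rebuilds a full-rank out_dims (the input's rank, with the decreased axis pinned to size 1) before the kernel's offset arithmetic runs; re-inserting the size-1 axis keeps the indexing math rank-consistent with input_dims, and pinning the axis to 1 leaves the element count from ProductDim unchanged.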
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
@@ -88,10 +88,12 @@ REGISTER_TRT_PLUGIN_V2(SlicePluginCreator);
 class SlicePluginDynamic : public DynamicPluginTensorRT {
  public:
   explicit SlicePluginDynamic(std::vector<int> starts, std::vector<int> ends,
-                              std::vector<int> axes, bool with_fp16);
+                              std::vector<int> axes, int decrease_axis,
+                              bool with_fp16);
   nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
-    return new SlicePluginDynamic(starts_, ends_, axes_, with_fp16_);
+    return new SlicePluginDynamic(starts_, ends_, axes_, decrease_axis_,
+                                  with_fp16_);
   }
   SlicePluginDynamic(void const* serialData, size_t serialLength);
@@ -140,6 +142,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
   std::vector<int> starts_;
   std::vector<int> ends_;
   std::vector<int> axes_;
+  int decrease_axis_;
   int* offset_temp_data_{nullptr};
   cudaEvent_t copy_event_;
   cudaStream_t copy_stream_;
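Note that clone() forwards the new decrease_axis_ field as well; TensorRT clones plugins during engine construction and serialization, so without this a cloned plugin would silently lose the decreased axis and keep its full-rank output.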
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py
@@ -55,11 +55,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
     def sample_program_configs(self):
         def generate_input1(attrs: List[Dict[str, Any]]):
-            return np.ones([1, 3, 64, 64]).astype(np.float32)
+            return np.ones([6, 6, 64, 64]).astype(np.float32)

         for axes in [[0, 1], [1, 3], [2, 3]]:
-            for starts in [[0, 1], [-4, -3]]:
-                for ends in [[2, 2], [-1, -2], [5, 5]]:
+            for starts in [[0, 1]]:
+                for ends in [[2, 2], [5, 5]]:
                     for decrease_axis in [[], [1], [2], [-1], [-100]]:
                         for infer_flags in [[-1]]:
                             dics = [{
@@ -97,8 +97,8 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
             self, program_config) -> (paddle_infer.Config, List[int], float):
         def generate_dynamic_shape(attrs):
             self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
-            self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
-            self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}
+            self.dynamic_shape.max_input_shape = {"input_data": [8, 8, 64, 64]}
+            self.dynamic_shape.opt_input_shape = {"input_data": [6, 6, 64, 64]}

         def clear_dynamic_shape():
             self.dynamic_shape.min_input_shape = {}
@@ -107,7 +107,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
         def generate_trt_nodes_num(attrs, dynamic_shape):
             inputs = program_config.inputs
-            if len(attrs[0]["decrease_axis"]) != 0:
-                return 0, 3
+            if dynamic_shape == True and len(attrs[0]["decrease_axis"]) == 0:
+                return 1, 2
+            if dynamic_shape == True and len(attrs[0]["decrease_axis"]) != 1:
+                return 0, 3
+            if dynamic_shape == False and len(attrs[0]["decrease_axis"]) != 0:
+                return 0, 3
             if dynamic_shape:
                 for i in range(len(attrs[0]["starts"])):
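Reading the new expectations: the pair returned by generate_trt_nodes_num is the (TensorRT engine count, op count) that the harness asserts after graph partitioning. A dynamic-shape run with an empty decrease_axis keeps the old expectation of full conversion (1 engine, 2 ops); dynamic shape with more than one decreased axis, and static shape with any decreased axis, are expected to fall back to Paddle (0 engines, 3 ops); the single-decreased-axis dynamic case falls through to the shape-based computation below.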
@@ -123,7 +127,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
             program_config.ops[i].attrs
             for i in range(len(program_config.ops))
         ]
+        self.trt_param.max_batch_size = 9
         # for static_shape
         clear_dynamic_shape()
         self.trt_param.precision = paddle_infer.PrecisionType.Float32
@@ -146,7 +150,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
         # TODO(inference): fix.
         # trt6 and trt7.1 has bug.
         # trt7.2 deserialize has bug.
-        # self.run_test()
+        self.run_test()
         pass
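The final hunk re-enables the previously skipped check: self.run_test() had been commented out because of known TRT 6/7.1 bugs and a TRT 7.2 deserialization bug, and this commit uncomments it while leaving the TODO notes in place.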