Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7ec1e9af
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7ec1e9af
编写于
4月 19, 2022
作者:
F
feng_shuai
提交者:
GitHub
4月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add trt supoort for slice op (#41467) (#41911)
上级
15d30815
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
62 addition
and
27 deletion
+62
-27
paddle/fluid/inference/tensorrt/convert/slice_op.cc
paddle/fluid/inference/tensorrt/convert/slice_op.cc
+6
-2
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+17
-12
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
+22
-3
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
+5
-2
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py
...id/tests/unittests/ir/inference/test_trt_convert_slice.py
+12
-8
未找到文件。
paddle/fluid/inference/tensorrt/convert/slice_op.cc
浏览文件 @
7ec1e9af
...
...
@@ -44,6 +44,8 @@ class SliceOpConverter : public OpConverter {
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"starts"
));
std
::
vector
<
int
>
ends
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"ends"
));
std
::
vector
<
int
>
decrease_axises
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"decrease_axis"
));
auto
input_dims
=
input
->
getDimensions
();
if
(
!
engine_
->
with_dynamic_shape
())
{
...
...
@@ -107,8 +109,10 @@ class SliceOpConverter : public OpConverter {
}
else
{
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
axes
,
with_fp16
);
int
decrease_axis
=
decrease_axises
.
size
()
==
0
?
-
1
:
decrease_axises
[
0
];
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
axes
,
decrease_axis
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
&
input
,
1
,
plugin
);
}
}
else
{
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
7ec1e9af
...
...
@@ -930,10 +930,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if
(
desc
.
HasAttr
(
"decrease_axis"
))
{
std
::
vector
<
int
>
decrease_axis
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"decrease_axis"
));
if
(
decrease_axis
.
size
()
>
0
)
{
VLOG
(
3
)
<<
"Invalid slice decrease_axis. decrease_axis.size() > 0"
"is not supported in TensorRT"
;
return
false
;
if
(
with_dynamic_shape
)
{
if
(
decrease_axis
.
size
()
>
1
)
{
return
false
;
}
}
else
{
if
(
decrease_axis
.
size
()
>
0
)
{
VLOG
(
3
)
<<
"Invalid slice decrease_axis. decrease_axis.size() > 0"
"is not supported in TensorRT"
;
return
false
;
}
}
}
...
...
@@ -1054,17 +1060,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return
false
;
}
if
(
desc
.
Input
(
"Ids"
).
size
()
!=
desc
.
Input
(
"Embs"
).
size
())
{
VLOG
(
3
)
<<
"The id and emb size of fused EmbEltwiseLayerNormOp "
"should be same "
;
return
false
;
}
}
if
(
op_type
==
"fused_preln_embedding_eltwise_layernorm"
)
{
if
(
!
with_dynamic_shape
)
{
VLOG
(
3
)
<<
"fused_preln_embedding_eltwise_layernorm should run on
dynamic "
"shape mode."
;
VLOG
(
3
)
<<
"fused_preln_embedding_eltwise_layernorm should run on "
"
dynamic "
"shape mode."
;
return
false
;
}
if
(
desc
.
Input
(
"Ids"
).
size
()
!=
desc
.
Input
(
"Embs"
).
size
())
{
...
...
@@ -1454,7 +1458,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
const auto y_shape = y_var_desc->GetShape();
if (y_shape.size() != 2) {
VLOG(3)
<< " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes = "
<< " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes =
"
<< y_shape.size();
return false;
}
...
...
@@ -1598,8 +1603,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
#else
if
(
dtype
!=
framework
::
proto
::
VarType
::
FP32
)
{
VLOG
(
3
)
<<
"reduce op input data type must be float32 using TensorRT
< 7.0"
;
VLOG
(
3
)
<<
"reduce op input data type must be float32 using TensorRT "
"
< 7.0"
;
return
false
;
}
#endif
...
...
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
浏览文件 @
7ec1e9af
...
...
@@ -205,8 +205,9 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
#if IS_TRT_VERSION_GE(6000)
SlicePluginDynamic
::
SlicePluginDynamic
(
std
::
vector
<
int
>
starts
,
std
::
vector
<
int
>
ends
,
std
::
vector
<
int
>
axes
,
bool
with_fp16
)
:
starts_
(
starts
),
ends_
(
ends
),
axes_
(
axes
)
{
std
::
vector
<
int
>
axes
,
int
decrease_axis
,
bool
with_fp16
)
:
starts_
(
starts
),
ends_
(
ends
),
axes_
(
axes
),
decrease_axis_
(
decrease_axis
)
{
with_fp16_
=
with_fp16
;
cudaEventCreate
(
&
copy_event_
);
cudaStreamCreate
(
&
copy_stream_
);
...
...
@@ -217,6 +218,7 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
starts_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ends_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
axes_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
decrease_axis_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
with_fp16_
);
cudaEventCreate
(
&
copy_event_
);
cudaStreamCreate
(
&
copy_stream_
);
...
...
@@ -233,7 +235,8 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
size_t
SlicePluginDynamic
::
getSerializationSize
()
const
TRT_NOEXCEPT
{
size_t
size
=
SerializedSize
(
starts_
)
+
SerializedSize
(
ends_
)
+
SerializedSize
(
axes_
)
+
SerializedSize
(
with_fp16_
);
SerializedSize
(
axes_
)
+
SerializedSize
(
decrease_axis_
)
+
SerializedSize
(
with_fp16_
);
return
size
;
}
...
...
@@ -242,6 +245,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
SerializeValue
(
&
buffer
,
starts_
);
SerializeValue
(
&
buffer
,
ends_
);
SerializeValue
(
&
buffer
,
axes_
);
SerializeValue
(
&
buffer
,
decrease_axis_
);
SerializeValue
(
&
buffer
,
with_fp16_
);
}
...
...
@@ -265,6 +269,17 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
ret
.
d
[
axes_
[
i
]]
=
expr_builder
.
constant
(
end
-
start
);
#endif
}
if
(
decrease_axis_
!=
-
1
)
{
nvinfer1
::
DimsExprs
res
;
res
.
nbDims
=
ret
.
nbDims
-
1
;
int
j
=
0
;
for
(
size_t
i
=
0
;
i
<
in_dims
.
nbDims
;
i
++
)
{
if
(
decrease_axis_
==
i
)
continue
;
res
.
d
[
j
++
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kMAX
,
*
expr_builder
.
constant
(
0
),
*
ret
.
d
[
i
]);
}
return
res
;
}
return
ret
;
}
...
...
@@ -318,6 +333,10 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
auto
input_dims
=
input_desc
[
0
].
dims
;
auto
out_dims
=
output_desc
[
0
].
dims
;
if
(
decrease_axis_
!=
-
1
)
{
out_dims
=
input_dims
;
out_dims
.
d
[
decrease_axis_
]
=
1
;
}
auto
num_dims
=
input_dims
.
nbDims
;
size_t
out_num
=
ProductDim
(
out_dims
);
...
...
paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
浏览文件 @
7ec1e9af
...
...
@@ -88,10 +88,12 @@ REGISTER_TRT_PLUGIN_V2(SlicePluginCreator);
class
SlicePluginDynamic
:
public
DynamicPluginTensorRT
{
public:
explicit
SlicePluginDynamic
(
std
::
vector
<
int
>
starts
,
std
::
vector
<
int
>
ends
,
std
::
vector
<
int
>
axes
,
bool
with_fp16
);
std
::
vector
<
int
>
axes
,
int
decrease_axis
,
bool
with_fp16
);
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
{
return
new
SlicePluginDynamic
(
starts_
,
ends_
,
axes_
,
with_fp16_
);
return
new
SlicePluginDynamic
(
starts_
,
ends_
,
axes_
,
decrease_axis_
,
with_fp16_
);
}
SlicePluginDynamic
(
void
const
*
serialData
,
size_t
serialLength
);
...
...
@@ -140,6 +142,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
std
::
vector
<
int
>
starts_
;
std
::
vector
<
int
>
ends_
;
std
::
vector
<
int
>
axes_
;
int
decrease_axis_
;
int
*
offset_temp_data_
{
nullptr
};
cudaEvent_t
copy_event_
;
cudaStream_t
copy_stream_
;
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py
浏览文件 @
7ec1e9af
...
...
@@ -55,11 +55,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
def
sample_program_configs
(
self
):
def
generate_input1
(
attrs
:
List
[
Dict
[
str
,
Any
]]):
return
np
.
ones
([
1
,
3
,
64
,
64
]).
astype
(
np
.
float32
)
return
np
.
ones
([
6
,
6
,
64
,
64
]).
astype
(
np
.
float32
)
for
axes
in
[[
0
,
1
],
[
1
,
3
],
[
2
,
3
]]:
for
starts
in
[[
0
,
1
]
,
[
-
4
,
-
3
]
]:
for
ends
in
[[
2
,
2
],
[
-
1
,
-
2
],
[
5
,
5
]]:
for
starts
in
[[
0
,
1
]]:
for
ends
in
[[
2
,
2
],
[
5
,
5
]]:
for
decrease_axis
in
[[],
[
1
],
[
2
],
[
-
1
],
[
-
100
]]:
for
infer_flags
in
[[
-
1
]]:
dics
=
[{
...
...
@@ -97,8 +97,8 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
self
,
program_config
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
3
,
32
,
32
]}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
4
,
3
,
64
,
64
]}
self
.
dynamic_shape
.
opt_input_shape
=
{
"input_data"
:
[
1
,
3
,
64
,
64
]}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
8
,
8
,
64
,
64
]}
self
.
dynamic_shape
.
opt_input_shape
=
{
"input_data"
:
[
6
,
6
,
64
,
64
]}
def
clear_dynamic_shape
():
self
.
dynamic_shape
.
min_input_shape
=
{}
...
...
@@ -107,7 +107,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
def
generate_trt_nodes_num
(
attrs
,
dynamic_shape
):
inputs
=
program_config
.
inputs
if
len
(
attrs
[
0
][
"decrease_axis"
])
!=
0
:
if
dynamic_shape
==
True
and
len
(
attrs
[
0
][
"decrease_axis"
])
==
0
:
return
1
,
2
if
dynamic_shape
==
True
and
len
(
attrs
[
0
][
"decrease_axis"
])
!=
1
:
return
0
,
3
if
dynamic_shape
==
False
and
len
(
attrs
[
0
][
"decrease_axis"
])
!=
0
:
return
0
,
3
if
dynamic_shape
:
for
i
in
range
(
len
(
attrs
[
0
][
"starts"
])):
...
...
@@ -123,7 +127,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
self
.
trt_param
.
max_batch_size
=
9
# for static_shape
clear_dynamic_shape
()
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Float32
...
...
@@ -146,7 +150,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
# TODO(inference): fix.
# trt6 and trt7.1 has bug.
# trt7.2 deserialize has bug.
#
self.run_test()
self
.
run_test
()
pass
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录