Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
20c3224d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
20c3224d
编写于
11月 28, 2022
作者:
X
xiaoxiaohehe001
提交者:
GitHub
11月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference] Add gather_nd trt converter. (#47589)
* add_gather_nd_ * add_gather_nd_ * add_gather_nd_
上级
827fd5cd
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
36 addition
and
22 deletion
+36
-22
paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc
paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc
+15
-4
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+5
-2
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py
...ests/unittests/ir/inference/test_trt_convert_gather_nd.py
+16
-16
未找到文件。
paddle/fluid/inference/tensorrt/convert/gather_nd_op.cc
浏览文件 @
20c3224d
...
...
@@ -24,13 +24,24 @@ class GatherNdOpConverter : public OpConverter {
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert a paddle gather_nd op to tensorrt gather_nd plugin"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
index
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Index"
)[
0
]);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
// AddGatherV2 is supported by the trt version of 8.2.
#if IS_TRT_VERSION_GE(8200)
VLOG
(
3
)
<<
"convert gather_nd op to tensorrt gather_nd layer"
;
auto
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
GatherV2
,
*
input
,
*
index
,
nvinfer1
::
GatherMode
::
kND
);
layer
->
setNbElementWiseDims
(
0
);
RreplenishLayerAndOutput
(
layer
,
"gather_nd"
,
{
output_name
},
test_mode
);
#else
VLOG
(
4
)
<<
"convert a paddle gather_nd op to tensorrt gather_nd plugin"
;
// Declare inputs
std
::
vector
<
nvinfer1
::
ITensor
*>
inputs
;
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
index
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Index"
)[
0
]);
inputs
.
emplace_back
(
input
);
inputs
.
emplace_back
(
index
);
...
...
@@ -41,7 +52,6 @@ class GatherNdOpConverter : public OpConverter {
layer
=
engine_
->
AddDynamicPlugin
(
inputs
.
data
(),
inputs
.
size
(),
plugin
);
std
::
string
layer_name
=
"gather_nd (Output: "
;
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
layer_name
+=
output_name
;
...
...
@@ -49,6 +59,7 @@ class GatherNdOpConverter : public OpConverter {
engine_
->
DeclareOutput
(
output_name
);
}
layer
->
setName
((
layer_name
+
")"
).
c_str
());
#endif
}
};
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
20c3224d
...
...
@@ -566,9 +566,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass."
;
return
false
;
}
auto
x_var_name
=
desc
.
Input
(
"X"
)[
0
];
auto
index_var_name
=
desc
.
Input
(
"Index"
)[
0
];
auto
*
x_var_desc
=
block
->
FindVar
(
x_var_name
);
auto
*
index_var_desc
=
block
->
FindVar
(
index_var_name
);
// The index input must be int32 datatype.
...
...
@@ -578,6 +577,9 @@ struct SimpleOpTypeSetTeller : public Teller {
return
false
;
}
#if IS_TRT_VERSION_LT(8200)
auto
x_var_name
=
desc
.
Input
(
"X"
)[
0
];
auto
*
x_var_desc
=
block
->
FindVar
(
x_var_name
);
const
auto
index_shape
=
index_var_desc
->
GetShape
();
const
auto
x_shape
=
x_var_desc
->
GetShape
();
if
(
x_shape
.
size
()
<=
2
)
{
...
...
@@ -591,6 +593,7 @@ struct SimpleOpTypeSetTeller : public Teller {
<<
" ] not equal to x dims size ["
<<
x_shape
.
size
()
<<
"]"
;
return
false
;
}
#endif
}
if
(
op_type
==
"anchor_generator"
)
{
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py
浏览文件 @
20c3224d
...
...
@@ -69,11 +69,11 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest):
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
8
,
8
,
8
],
"input_data"
:
[
2
,
32
,
64
,
64
],
"index_data"
:
[
1
],
}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
4
,
32
,
64
,
64
],
"input_data"
:
[
2
,
32
,
64
,
64
],
"index_data"
:
[
1
],
}
self
.
dynamic_shape
.
opt_input_shape
=
{
...
...
@@ -159,11 +159,11 @@ class TrtConvertGatherNdTest_dim_4_1_2(TrtLayerAutoScanTest):
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
8
,
8
,
8
],
"input_data"
:
[
2
,
32
,
64
,
64
],
"index_data"
:
[
2
],
}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
4
,
32
,
64
,
64
],
"input_data"
:
[
2
,
32
,
64
,
64
],
"index_data"
:
[
2
],
}
self
.
dynamic_shape
.
opt_input_shape
=
{
...
...
@@ -249,11 +249,11 @@ class TrtConvertGatherNdTest_dim_4_2(TrtLayerAutoScanTest):
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
8
,
8
,
8
],
"input_data"
:
[
2
,
32
,
64
,
64
],
"index_data"
:
[
2
,
2
],
}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
4
,
32
,
64
,
64
],
"input_data"
:
[
2
,
32
,
64
,
64
],
"index_data"
:
[
2
,
2
],
}
self
.
dynamic_shape
.
opt_input_shape
=
{
...
...
@@ -339,11 +339,11 @@ class TrtConvertGatherNdTest_dim_4_3(TrtLayerAutoScanTest):
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
8
,
8
,
8
],
"input_data"
:
[
2
,
32
,
64
,
64
],
"index_data"
:
[
2
,
2
,
4
],
}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
4
,
32
,
64
,
64
],
"input_data"
:
[
2
,
32
,
64
,
64
],
"index_data"
:
[
2
,
2
,
4
],
}
self
.
dynamic_shape
.
opt_input_shape
=
{
...
...
@@ -429,15 +429,15 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest):
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
4
],
"input_data"
:
[
2
,
32
],
"index_data"
:
[
2
,
2
],
}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
4
,
64
],
"input_data"
:
[
2
,
32
],
"index_data"
:
[
2
,
2
],
}
self
.
dynamic_shape
.
opt_input_shape
=
{
"input_data"
:
[
2
,
8
],
"input_data"
:
[
2
,
32
],
"index_data"
:
[
2
,
2
],
}
...
...
@@ -521,15 +521,15 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest):
)
->
(
paddle_infer
.
Config
,
List
[
int
],
float
):
def
generate_dynamic_shape
(
attrs
):
self
.
dynamic_shape
.
min_input_shape
=
{
"input_data"
:
[
1
,
4
,
4
],
"index_data"
:
[
1
,
1
,
1
],
"input_data"
:
[
1
6
,
32
,
256
],
"index_data"
:
[
2
,
2
,
2
],
}
self
.
dynamic_shape
.
max_input_shape
=
{
"input_data"
:
[
16
,
64
,
512
],
"index_data"
:
[
4
,
2
,
4
],
"input_data"
:
[
16
,
32
,
256
],
"index_data"
:
[
2
,
2
,
2
],
}
self
.
dynamic_shape
.
opt_input_shape
=
{
"input_data"
:
[
2
,
8
,
64
],
"input_data"
:
[
16
,
32
,
256
],
"index_data"
:
[
2
,
2
,
2
],
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录