Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e568268b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e568268b
编写于
4月 19, 2022
作者:
J
JingZhuangzhuang
提交者:
GitHub
4月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix_poo2d_trt_convert (#41860) (#41915)
上级
aa6eb0e8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
44 addition
and
14 deletion
+44
-14
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
+44
-14
未找到文件。
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
浏览文件 @
e568268b
...
@@ -256,19 +256,51 @@ class Pool2dOpConverter : public OpConverter {
...
@@ -256,19 +256,51 @@ class Pool2dOpConverter : public OpConverter {
if
(
!
adaptive
)
{
if
(
!
adaptive
)
{
if
(
ceil_mode
)
{
if
(
ceil_mode
)
{
std
::
vector
<
int
>
input_shape_v
;
if
(
nv_ksize
.
d
[
0
]
%
nv_strides
.
d
[
0
]
==
0
&&
for
(
int
i
=
0
;
i
<
input_dims
;
i
++
)
{
nv_ksize
.
d
[
1
]
%
nv_strides
.
d
[
1
]
==
0
)
{
input_shape_v
.
push_back
(
input_shape
.
d
[
i
]);
nvinfer1
::
DimsHW
pre_pad
(
0
,
0
);
nvinfer1
::
DimsHW
post_pad
(
0
,
0
);
// If ceil mode is true, we will pad the appropriate size to the
// input.
DealCeilMode
(
input_shape
,
ksize
,
strides
,
paddings
,
&
pre_pad
,
&
post_pad
,
input_dims
);
auto
*
pad_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Padding
,
*
input1
,
pre_pad
,
post_pad
);
PADDLE_ENFORCE_NOT_NULL
(
pad_layer
,
platform
::
errors
::
Fatal
(
"Pad layer in poolOp converter could not be "
"created. The pointer to pad layer is `NULL`."
));
input1
=
pad_layer
->
getOutput
(
0
);
auto
*
pool_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
input1
,
nv_pool_type
,
nv_ksize
);
PADDLE_ENFORCE_NOT_NULL
(
pool_layer
,
platform
::
errors
::
Fatal
(
"trt pool layer in converter could not be created."
));
pool_layer
->
setStride
(
nv_strides
);
pool_layer
->
setPadding
(
nv_paddings
);
if
(
padding_algorithm
==
"SAME"
)
{
pool_layer
->
setPaddingMode
(
nvinfer1
::
PaddingMode
::
kSAME_UPPER
);
}
pool_layer
->
setAverageCountExcludesPadding
(
exclusive
);
layer
=
pool_layer
;
}
else
{
std
::
vector
<
int
>
input_shape_v
;
for
(
int
i
=
0
;
i
<
input_dims
;
i
++
)
{
input_shape_v
.
push_back
(
input_shape
.
d
[
i
]);
}
plugin
::
PoolPlugin
*
plugin
=
new
plugin
::
PoolPlugin
(
ceil_mode
,
plugin_pool_type
,
adaptive
,
exclusive
,
ksize
,
strides
,
paddings
,
input_shape_v
,
real_paddings
);
auto
*
pool_layer
=
engine_
->
AddPlugin
(
&
input1
,
1
,
plugin
);
PADDLE_ENFORCE_NOT_NULL
(
pool_layer
,
platform
::
errors
::
Fatal
(
"trt pool plugin layer in converter could not be created."
));
layer
=
pool_layer
;
}
}
plugin
::
PoolPlugin
*
plugin
=
new
plugin
::
PoolPlugin
(
ceil_mode
,
plugin_pool_type
,
adaptive
,
exclusive
,
ksize
,
strides
,
paddings
,
input_shape_v
,
real_paddings
);
auto
*
pool_layer
=
engine_
->
AddPlugin
(
&
input1
,
1
,
plugin
);
PADDLE_ENFORCE_NOT_NULL
(
pool_layer
,
platform
::
errors
::
Fatal
(
"trt pool plugin layer in converter could not be created."
));
layer
=
pool_layer
;
}
else
{
}
else
{
#if IS_TRT_VERSION_GE(8000)
#if IS_TRT_VERSION_GE(8000)
// Exclude padding pixels from the average mean is not supported well by
// Exclude padding pixels from the average mean is not supported well by
...
@@ -299,7 +331,6 @@ class Pool2dOpConverter : public OpConverter {
...
@@ -299,7 +331,6 @@ class Pool2dOpConverter : public OpConverter {
pool_layer
->
setAverageCountExcludesPadding
(
exclusive
);
pool_layer
->
setAverageCountExcludesPadding
(
exclusive
);
layer
=
pool_layer
;
layer
=
pool_layer
;
}
}
}
else
{
}
else
{
// Average pooling needs to exclude the padding pixels from the average
// Average pooling needs to exclude the padding pixels from the average
// mean.
// mean.
...
@@ -327,5 +358,4 @@ class Pool2dOpConverter : public OpConverter {
...
@@ -327,5 +358,4 @@ class Pool2dOpConverter : public OpConverter {
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
USE_OP_ITSELF
(
pool2d
);
REGISTER_TRT_OP_CONVERTER
(
pool2d
,
Pool2dOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
pool2d
,
Pool2dOpConverter
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录