Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
32211fe9
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
32211fe9
编写于
3月 03, 2021
作者:
P
Pei Yang
提交者:
GitHub
3月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
TRT conv2d converter support SAME padding (#31379)
上级
e312a1ff
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
10 addition
and
14 deletion
+10
-14
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+7
-0
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+1
-7
python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py
.../fluid/tests/unittests/ir/inference/test_trt_conv_pass.py
+2
-7
未找到文件。
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
浏览文件 @
32211fe9
...
@@ -97,6 +97,10 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
...
@@ -97,6 +97,10 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"strides"
));
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"strides"
));
const
std
::
vector
<
int
>
paddings
=
const
std
::
vector
<
int
>
paddings
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"paddings"
));
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
.
GetAttr
(
"paddings"
));
std
::
string
padding_algorithm
=
"EXPLICIT"
;
if
(
op_desc
.
HasAttr
(
"padding_algorithm"
))
padding_algorithm
=
BOOST_GET_CONST
(
std
::
string
,
op_desc
.
GetAttr
(
"padding_algorithm"
));
nvinfer1
::
DimsHW
nv_ksize
(
filter_h
,
filter_w
);
nvinfer1
::
DimsHW
nv_ksize
(
filter_h
,
filter_w
);
nvinfer1
::
DimsHW
nv_dilations
(
dilations
[
0
],
dilations
[
1
]);
nvinfer1
::
DimsHW
nv_dilations
(
dilations
[
0
],
dilations
[
1
]);
...
@@ -126,6 +130,9 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
...
@@ -126,6 +130,9 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
layer
->
setStride
(
nv_strides
);
layer
->
setStride
(
nv_strides
);
layer
->
setPadding
(
nv_paddings
);
layer
->
setPadding
(
nv_paddings
);
layer
->
setNbGroups
(
groups
);
layer
->
setNbGroups
(
groups
);
if
(
padding_algorithm
==
"SAME"
)
{
layer
->
setPaddingMode
(
nvinfer1
::
PaddingMode
::
kSAME_UPPER
);
}
// set dilations
// set dilations
fset_dilation
(
layer
,
nv_dilations
);
fset_dilation
(
layer
,
nv_dilations
);
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
32211fe9
...
@@ -129,13 +129,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...
@@ -129,13 +129,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
std
::
vector
<
int
>
paddings
=
std
::
vector
<
int
>
paddings
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"paddings"
));
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"paddings"
));
std
::
string
padding_algorithm
=
"EXPLICIT"
;
if
(
paddings
.
size
()
>
2
)
return
false
;
if
(
desc
.
HasAttr
(
"padding_algorithm"
))
padding_algorithm
=
BOOST_GET_CONST
(
std
::
string
,
desc
.
GetAttr
(
"padding_algorithm"
));
if
(
paddings
.
size
()
>
2
||
(
padding_algorithm
==
"SAME"
&&
op_type
!=
"pool2d"
))
return
false
;
}
}
if
(
op_type
==
"matmul"
)
{
if
(
op_type
==
"matmul"
)
{
auto
*
block
=
desc
.
Block
();
auto
*
block
=
desc
.
Block
();
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py
浏览文件 @
32211fe9
...
@@ -67,15 +67,12 @@ class TensorRTSubgraphPassConvValidPaddingTest(TensorRTSubgraphPassConvTest):
...
@@ -67,15 +67,12 @@ class TensorRTSubgraphPassConvValidPaddingTest(TensorRTSubgraphPassConvTest):
self
.
conv_padding
=
'VALID'
self
.
conv_padding
=
'VALID'
'''
# conv2d padded in 'SAME' mode is not yet supported in TRT, reopen this when support is complete.
class
TensorRTSubgraphPassConvSamePaddingTest
(
InferencePassTest
):
class
TensorRTSubgraphPassConvSamePaddingTest
(
InferencePassTest
):
def
set_params
(
self
):
def
set_params
(
self
):
self
.
conv_num_filters
=
6
self
.
conv_num_filters
=
6
self
.
conv_filter_size
=
6
self
.
conv_filter_size
=
6
self
.
conv_groups
=
3
self
.
conv_groups
=
3
self
.
conv_padding
=
'SAME'
self
.
conv_padding
=
'SAME'
'''
class
TensorRTSubgraphPassDepthwiseConvTest
(
TensorRTSubgraphPassConvTest
):
class
TensorRTSubgraphPassDepthwiseConvTest
(
TensorRTSubgraphPassConvTest
):
...
@@ -131,15 +128,13 @@ class TensorRTSubgraphPassConvTransposeValidPaddingTest(
...
@@ -131,15 +128,13 @@ class TensorRTSubgraphPassConvTransposeValidPaddingTest(
self
.
conv_padding
=
'VALID'
self
.
conv_padding
=
'VALID'
'''
class
TensorRTSubgraphPassConvTransposeSamePaddingTest
(
# conv2d_transpose padded in 'SAME' mode is not yet supported in TRT, reopen this when support is complete.
TensorRTSubgraphPassConvTransposeTest
):
class TensorRTSubgraphPassConvTransposeSamePaddingTest(TensorRTSubgraphPassConvTransposeTest):
def
set_params
(
self
):
def
set_params
(
self
):
self
.
conv_num_filters
=
6
self
.
conv_num_filters
=
6
self
.
conv_filter_size
=
6
self
.
conv_filter_size
=
6
self
.
conv_groups
=
1
self
.
conv_groups
=
1
self
.
conv_padding
=
'SAME'
self
.
conv_padding
=
'SAME'
'''
class
TensorRTSubgraphPassDepthwiseConvTransposeTest
(
class
TensorRTSubgraphPassDepthwiseConvTransposeTest
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录