BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 7766721a (unverified)
Authored by wenbin on May 31, 2021; committed by GitHub on May 31, 2021
disable conv plugin in TRT old versions (#33198)
Parent: d7d3090f

Showing 5 changed files with 86 additions and 19 deletions (+86, -19)
paddle/fluid/inference/tensorrt/convert/activation_op.cc (+0, -5)
paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc (+0, -10)
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc (+0, -4)
paddle/fluid/inference/tensorrt/op_teller.cc (+21, -0)
python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py (+65, -0)
paddle/fluid/inference/tensorrt/convert/activation_op.cc
...
@@ -52,11 +52,6 @@ class ActivationOpConverter : public OpConverter {
         engine_->GetITensor(op_desc.Input("X")[0]);
     auto op_pair = ops.find(op_type_);
-    if (op_pair == ops.end()) {
-      PADDLE_THROW(platform::errors::Fatal(
-          "Wrong activation op type, the trt do not support the %s act type.",
-          op_type_));
-    }
     nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
         engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
...
paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc
...
@@ -55,16 +55,6 @@ class AffineChannelOpConverter : public OpConverter {
     auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
     float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false);
-    auto data_layout = framework::StringToDataLayout(
-        BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout")));
-    PADDLE_ENFORCE_EQ(
-        data_layout, framework::DataLayout::kNCHW,
-        platform::errors::InvalidArgument(
-            "TensorRT affine channel converter can only convert NCHW format. "
-            "Other format should be run in fluid mode. Report a bug on github "
-            "issue if you see this line."));
     // tensorrt scalend layer only support spatial dims >= 2,
     // so nhwc is not availabe (spatial dims == 0)
     const int channel_axis = engine_->with_dynamic_shape();
...
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
...
@@ -25,10 +25,6 @@ static bool CheckDims(const nvinfer1::Dims& dims_x,
     return false;
   }
   for (int i = 0; i < dims_x.nbDims; i++) {
-    // conservative judgment
-    if (dims_x.d[i] == -1 || dims_y.d[i] == -1) {
-      return false;
-    }
     if (dims_x.d[i] != dims_y.d[i]) {
       return false;
     }
...
paddle/fluid/inference/tensorrt/op_teller.cc
...
@@ -225,6 +225,27 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
               << desc.Output("Output").size() << " output.";
       return false;
     }
+    // strides > 1 and 'SAME' is only supported by trt7.0 above
+#if !IS_TRT_VERSION_GE(7000)
+    if (op_type == "conv2d" || op_type == "conv2d_fusion" ||
+        op_type == "depthwise_conv2d") {
+      if (desc.HasAttr("padding_algorithm") && with_dynamic_shape) {
+        auto padding_algorithm =
+            BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
+        if (padding_algorithm == "SAME" && desc.HasAttr("strides")) {
+          const std::vector<int> strides =
+              BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides"));
+          // there is no issue if strides.size() less than 2
+          if (strides.size() > 1) {
+            for (size_t i = 0; i < strides.size(); i++) {
+              if (strides[i] > 1) return false;
+            }
+          }
+        }
+      }
+    }
+#endif
   }
   if (op_type == "matmul") {
...
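
The gate in this hunk relies on Paddle's IS_TRT_VERSION_GE macro, where 7000 encodes TensorRT 7.0. As a point of reference only, the minimal sketch below shows how such a compile-time version gate can be assembled from TensorRT's own NV_TENSORRT_MAJOR / NV_TENSORRT_MINOR / NV_TENSORRT_PATCH macros; the SKETCH_* names are invented for illustration and are not Paddle's actual definitions, which live in its TensorRT helper headers and may differ.

// Minimal sketch (hypothetical SKETCH_* names, not Paddle's real helper):
// fold the TensorRT version into one integer so that 7000 means 7.0.0,
// matching the IS_TRT_VERSION_GE(7000) guard added above.
#include <NvInfer.h>  // ships with TensorRT; provides the NV_TENSORRT_* macros
#include <cstdio>

#define SKETCH_TRT_VERSION \
  (NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + NV_TENSORRT_PATCH * 10)
#define SKETCH_IS_TRT_VERSION_GE(v) (SKETCH_TRT_VERSION >= (v))

int main() {
#if !SKETCH_IS_TRT_VERSION_GE(7000)
  // TensorRT < 7.0: conv2d / conv2d_fusion / depthwise_conv2d with
  // padding_algorithm == "SAME" and any stride > 1 is rejected by the teller
  // above and left to the framework implementation instead of TensorRT.
  std::printf("TensorRT < 7.0: strided SAME-padded conv stays off TensorRT\n");
#else
  std::printf("TensorRT >= 7.0: strided SAME-padded conv may be converted\n");
#endif
  return 0;
}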
python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py
...
@@ -161,5 +161,70 @@ class TensorRTSubgraphPassDepthwiseConvTransposeTest(
         self.use_cudnn = False


+class DynamicShapeTensorRTSubgraphPassConvTest(InferencePassTest):
+    def setUp(self):
+        self.set_params()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 6, -1, -1], dtype="float32")
+            conv_out = fluid.layers.conv2d(
+                input=data,
+                num_filters=self.conv_num_filters,
+                filter_size=self.conv_filter_size,
+                groups=self.conv_groups,
+                padding=self.conv_padding,
+                bias_attr=False,
+                use_cudnn=self.use_cudnn,
+                stride=self.stride,
+                act=None)
+        self.feeds = {
+            "data": np.random.random([32, 6, 64, 64]).astype("float32"),
+        }
+        self.enable_trt = True
+        self.trt_parameters = DynamicShapeTensorRTSubgraphPassConvTest.TensorRTParam(
+            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
+        self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConvTest.DynamicShapeParam(
+            {
+                "conv2d_0.tmp_0": [1, 6, 8, 8],
+                "data": [1, 6, 8, 8],
+                "depthwise_conv2d_0.tmp_0": [1, 6, 8, 8]
+            }, {
+                "conv2d_0.tmp_0": [32, 6, 64, 64],
+                "data": [32, 6, 64, 64],
+                "depthwise_conv2d_0.tmp_0": [32, 6, 64, 64]
+            }, {
+                "conv2d_0.tmp_0": [16, 6, 16, 16],
+                "data": [16, 6, 16, 16],
+                "depthwise_conv2d_0.tmp_0": [32, 6, 64, 64]
+            }, False)
+        self.fetch_list = [conv_out]
+
+    def set_params(self):
+        self.conv_num_filters = 6
+        self.conv_filter_size = 6
+        self.conv_groups = 6
+        self.conv_padding = 'SAME'
+        self.use_cudnn = True
+        self.stride = [2, 2]
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            use_gpu = True
+            self.check_output_with_option(use_gpu)
+            self.assertTrue(
+                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
+
+
+class DynamicShapeTensorRTSubgraphPassDepthwiseConvTransposeTest(
+        DynamicShapeTensorRTSubgraphPassConvTest):
+    def set_params(self):
+        self.conv_num_filters = 6
+        self.conv_filter_size = 6
+        self.conv_groups = 6
+        self.conv_padding = 'SAME'
+        self.use_cudnn = False
+        self.stride = [2, 2]
+
+
 if __name__ == "__main__":
     unittest.main()
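
For context, and stated tentatively since these helpers are not part of this diff: in Paddle's InferencePassTest utilities, the three dictionaries passed to DynamicShapeParam appear to be the minimum, maximum, and optimal input-shape ranges registered for TensorRT's dynamic-shape profile, with the trailing False disabling the TRT plugin FP16 path; TensorRTParam(1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) appears to bundle the workspace size, max batch size, minimum subgraph size, precision, and the use_static / use_calib_mode switches.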