Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
42847d2e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
42847d2e
编写于
9月 10, 2021
作者:
W
wenbin
提交者:
GitHub
9月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
conv3d (#35507)
* conv3d * remove const_cast * modify ut * disable dynamic shape for trt6.0 * remove trt5
上级
512329b0
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
583 addition
and
50 deletion
+583
-50
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+2
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/convert/conv3d_op.cc
paddle/fluid/inference/tensorrt/convert/conv3d_op.cc
+170
-0
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+9
-5
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+104
-45
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
.../paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+2
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py
.../fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py
+158
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py
...ts/unittests/ir/inference/test_trt_conv3d_transpose_op.py
+137
-0
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
42847d2e
...
...
@@ -1257,6 +1257,8 @@ USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER
(
gather_nd
);
USE_TRT_CONVERTER
(
reduce_mean
);
USE_TRT_CONVERTER
(
tile
);
USE_TRT_CONVERTER
(
conv3d
);
USE_TRT_CONVERTER
(
conv3d_transpose
);
#endif
namespace
paddle_infer
{
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
42847d2e
...
...
@@ -16,6 +16,7 @@ nv_library(tensorrt_converter
reduce_op.cc
gather_nd_op.cc
tile_op.cc
conv3d_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
...
...
paddle/fluid/inference/tensorrt/convert/conv3d_op.cc
0 → 100644
浏览文件 @
42847d2e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
// Forward declarations of the framework types that appear in the converter
// signatures below (framework::Scope and framework::proto::OpDesc).
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
proto
{
class
OpDesc
;
}
// namespace proto
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
 * Shared conversion routine for the fluid conv3d / conv3d_transpose ops.
 *
 * Reads the "Input" tensor and the "Filter" variable of `op`, builds a layer
 * through `fadd_layer`, configures stride / padding / groups / dilation on it
 * and registers the layer output in `engine` under the op's "Output" name.
 *
 * engine        - engine that owns the input tensors and receives the layer.
 * op            - protobuf description of the op being converted.
 * scope         - scope in which the filter variable is looked up.
 * test_mode     - when true, the output tensor is also declared as an engine
 *                 output.
 * fadd_layer    - callable (input, n_output, ksize, weight, bias) that creates
 *                 and returns the layer; a null result is treated as fatal.
 * fset_dilation - callable (layer, dilations) that applies the dilation.
 * name          - op name used in log output and in the layer name.
 *
 * The layer is always created with an empty bias (null data, zero count).
 */
template <typename RegistFunc, typename SetDilationFunc>
void ConvertConv3d(TensorRTEngine* engine,
                   const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode,
                   RegistFunc fadd_layer, SetDilationFunc fset_dilation,
                   const std::string& name) {
  VLOG(3) << "convert a fluid " << name << " op to tensorrt layer without bias";

  framework::OpDesc op_desc(op, nullptr);

  auto* X = engine->GetITensor(op_desc.Input("Input").front());
  const std::string filter_var_name = op_desc.Input("Filter").front();
  auto* Y_v = scope.FindVar(filter_var_name);
  PADDLE_ENFORCE_NOT_NULL(
      Y_v, platform::errors::NotFound(
               "Can not find %s persistable var in scope.", filter_var_name));
  auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();

  float* weight_data = nullptr;
  // The int8 path is selected by the mere presence of the attribute.
  const bool enable_int8 = op_desc.HasAttr("enable_int8");
  if (enable_int8) {
    const float in_scale =
        BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
    const auto weight_scale =
        BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
    weight_data =
        engine->GetWeightCPUData(filter_var_name, Y_t, true, weight_scale);
    engine->SetTensorDynamicRange(X, in_scale);
  } else {
    weight_data = engine->GetWeightCPUData(filter_var_name, Y_t, false);
  }

  PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL,
                    platform::errors::InvalidArgument(
                        "The conv3d filter's dims size should be 5, but got %d",
                        Y_t->dims().size()));

  // Filter dims are read as [dim0, dim1, depth, height, width].
  const int n_output = Y_t->dims()[0];
  const int n_input = Y_t->dims()[1];
  const int filter_d = Y_t->dims()[2];
  const int filter_h = Y_t->dims()[3];
  const int filter_w = Y_t->dims()[4];

  const int groups = BOOST_GET_CONST(int, op_desc.GetAttr("groups"));
  const std::vector<int> dilations =
      BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("dilations"));
  const std::vector<int> strides =
      BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
  const std::vector<int> paddings =
      BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));

  // The three attribute vectors are indexed [0..2] below; reject any other
  // length up front instead of reading out of bounds.
  PADDLE_ENFORCE_EQ(
      dilations.size(), 3UL,
      platform::errors::InvalidArgument(
          "The %s dilations size should be 3, but got %d", name,
          dilations.size()));
  PADDLE_ENFORCE_EQ(
      strides.size(), 3UL,
      platform::errors::InvalidArgument(
          "The %s strides size should be 3, but got %d", name, strides.size()));
  PADDLE_ENFORCE_EQ(
      paddings.size(), 3UL,
      platform::errors::InvalidArgument(
          "The %s paddings size should be 3, but got %d", name,
          paddings.size()));

  std::string padding_algorithm = "EXPLICIT";
  if (op_desc.HasAttr("padding_algorithm")) {
    padding_algorithm =
        BOOST_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm"));
  }

  nvinfer1::Dims3 nv_ksize(filter_d, filter_h, filter_w);
  nvinfer1::Dims3 nv_dilations(dilations[0], dilations[1], dilations[2]);
  nvinfer1::Dims3 nv_strides(strides[0], strides[1], strides[2]);
  nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]);

  TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                static_cast<void*>(weight_data),
                                static_cast<size_t>(Y_t->numel())};
  // No bias is attached: null data pointer with a zero element count.
  float* bias_data = nullptr;
  size_t bias_size = 0;
  TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
                              static_cast<void*>(bias_data), bias_size};

  // In conv3d_transpose output channels = filter_dims[1] * groups
  const int layer_output_maps =
      (op_desc.Type() == "conv3d_transpose") ? n_input * groups : n_output;
  auto* layer = fadd_layer(X, layer_output_maps, nv_ksize, weight, bias);

  PADDLE_ENFORCE_NOT_NULL(
      layer, platform::errors::Fatal("TensorRT create conv3d/conv3d_transpose"
                                     " layer failed."));
  layer->setStrideNd(nv_strides);
  layer->setPaddingNd(nv_paddings);
  layer->setNbGroups(groups);
  if (padding_algorithm == "SAME") {
    layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
  }
  // set dilations
  fset_dilation(layer, nv_dilations);

  auto output_name = op_desc.Output("Output").front();
  layer->setName((name + " (Output: " + output_name + ")").c_str());
  layer->getOutput(0)->setName(output_name.c_str());
  engine->SetITensor(output_name, layer->getOutput(0));

  if (test_mode) {
    engine->DeclareOutput(output_name);
  }
}
/*
 * Converter for the fluid "conv3d" op.
 *
 * Delegates to ConvertConv3d, supplying a layer factory that adds a
 * ConvolutionNd layer to engine_ and a dilation setter that forwards to
 * IConvolutionLayer::setDilationNd.
 */
class Conv3dOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    // Layer factory: creates the convolution layer on this converter's engine.
    auto add_conv_layer = [&](nvinfer1::ITensor* inputs,
                              int n_output, /* Conv output maps */
                              nvinfer1::Dims& ksize,
                              TensorRTEngine::Weight& weight,
                              TensorRTEngine::Weight& bias)
        -> nvinfer1::IConvolutionLayer* {
      return TRT_ENGINE_ADD_LAYER(engine_, ConvolutionNd, *inputs, n_output,
                                  ksize, weight.get(), bias.get());
    };

    // Dilation setter: applies the requested dilation to the new layer.
    auto apply_dilation = [](nvinfer1::IConvolutionLayer* conv_layer,
                             nvinfer1::Dims& dilations) {
      conv_layer->setDilationNd(dilations);
    };

    ConvertConv3d(engine_, op, scope, test_mode, add_conv_layer,
                  apply_dilation, "conv3d");
  }
};
/*
 * Converter for the fluid "conv3d_transpose" op.
 *
 * Delegates to ConvertConv3d, supplying a layer factory that adds a
 * DeconvolutionNd layer to engine_. The dilation setter it passes does
 * nothing, so no dilation is configured on the deconvolution layer.
 */
class Deconv3dOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    // Layer factory: creates the deconvolution layer on this converter's
    // engine.
    auto add_deconv_layer = [&](nvinfer1::ITensor* inputs,
                                int n_output, /* Deconv input maps */
                                nvinfer1::Dims& ksize,
                                TensorRTEngine::Weight& weight,
                                TensorRTEngine::Weight& bias)
        -> nvinfer1::IDeconvolutionLayer* {
      return TRT_ENGINE_ADD_LAYER(engine_, DeconvolutionNd, *inputs, n_output,
                                  ksize, weight.get(), bias.get());
    };

    // Dilation setter: deliberately empty, both arguments are ignored.
    auto ignore_dilation = [](nvinfer1::IDeconvolutionLayer* /*layer*/,
                              nvinfer1::Dims& /*dilations*/) {};

    ConvertConv3d(engine_, op, scope, test_mode, add_deconv_layer,
                  ignore_dilation, "conv3d_transpose");
  }
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
// Register the two converter classes defined above under the op names
// "conv3d" and "conv3d_transpose".
REGISTER_TRT_OP_CONVERTER
(
conv3d
,
Conv3dOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
conv3d_transpose
,
Deconv3dOpConverter
);
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
42847d2e
...
...
@@ -76,11 +76,7 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
"TensorRT's tensor input requires at least 1 "
"dimensions, but input %s has %d dims."
,
input
,
shape
.
size
()));
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
platform
::
errors
::
InvalidArgument
(
"TensorRT's tensor input requires at most 4 "
"dimensions, but input %s has %d dims."
,
input
,
shape
.
size
()));
auto
ShapeStr
=
[](
const
std
::
vector
<
T
>&
shape
)
{
std
::
ostringstream
os
;
os
<<
"["
;
...
...
@@ -103,6 +99,14 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
input
,
ShapeStr
(
shape
)));
}
return
nvinfer1
::
Dims3
(
shape
[
1
],
shape
[
2
],
shape
[
3
]);
}
else
if
(
shape
.
size
()
==
5UL
)
{
if
(
shape
[
2
]
==
-
1
||
shape
[
3
]
==
-
1
||
shape
[
4
]
==
-
1
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The input [%s] shape of trt subgraph is %s, please enable "
"trt dynamic_shape mode by SetTRTDynamicShapeInfo."
,
input
,
ShapeStr
(
shape
)));
}
return
nvinfer1
::
Dims4
(
shape
[
1
],
shape
[
2
],
shape
[
3
],
shape
[
4
]);
}
else
if
(
shape
.
size
()
==
3UL
)
{
if
(
shape
[
1
]
==
-
1
||
shape
[
2
]
==
-
1
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
42847d2e
...
...
@@ -90,51 +90,51 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_mul"
,
"conv2d_transpose"
,
"hard_swish"
};
std
::
unordered_set
<
std
::
string
>
teller_set
{
"
mul"
,
"matmul
"
,
"conv2d
"
,
"conv2d_fusion
"
,
"pool2d
"
,
"relu
"
,
"softmax
"
,
"sigmoid
"
,
"hard_swish
"
,
"depthwise_conv2d
"
,
"batch_norm
"
,
"concat
"
,
"tanh
"
,
"pa
d"
,
"elementwise_add
"
,
"elementwise_mul
"
,
"dropout
"
,
"prelu
"
,
"
conv2d_transpose"
,
"depthwise_conv2d_transpose
"
,
"leaky_relu
"
,
"fc
"
,
"shuffle_channel
"
,
"swish
"
,
"split
"
,
"instance_norm
"
,
"gelu
"
,
"layer_norm
"
,
"scale
"
,
"stack
"
,
"transpose2
"
,
"transpose
"
,
"flatten2
"
,
"flatten
"
,
"gather
"
,
"gather_nd
"
,
"yolo_box
"
,
"roi_align
"
,
"affine_channel
"
,
"nearest_interp
"
,
"anchor_generator
"
,
"reduce_sum
"
,
"reduce_mean
"
,
};
std
::
unordered_set
<
std
::
string
>
teller_set
{
"mul"
,
"mat
mul"
,
"conv2d
"
,
"conv2d_fusion
"
,
"pool2d
"
,
"relu
"
,
"softmax
"
,
"sigmoid
"
,
"hard_swish
"
,
"depthwise_conv2d
"
,
"batch_norm
"
,
"concat
"
,
"tanh
"
,
"pad
"
,
"elementwise_ad
d"
,
"elementwise_mul
"
,
"dropout
"
,
"prelu
"
,
"conv2d_transpose
"
,
"depthwise_
conv2d_transpose"
,
"leaky_relu
"
,
"fc
"
,
"shuffle_channel
"
,
"swish
"
,
"split
"
,
"instance_norm
"
,
"gelu
"
,
"layer_norm
"
,
"scale
"
,
"stack
"
,
"transpose2
"
,
"transpose
"
,
"flatten2
"
,
"flatten
"
,
"gather
"
,
"gather_nd
"
,
"yolo_box
"
,
"roi_align
"
,
"affine_channel
"
,
"nearest_interp
"
,
"anchor_generator
"
,
"reduce_sum
"
,
"reduce_mean
"
,
"conv3d
"
,
"conv3d_transpose"
};
};
bool
OpTeller
::
Tell
(
const
framework
::
ir
::
Node
*
node
,
bool
use_no_calib_int8
,
...
...
@@ -767,6 +767,65 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
#endif
if
(
op_type
==
"conv3d"
||
op_type
==
"conv3d_transpose"
)
{
if
(
desc
.
HasAttr
(
"padding_algorithm"
))
{
std
::
string
padding_algorithm
=
BOOST_GET_CONST
(
std
::
string
,
desc
.
GetAttr
(
"padding_algorithm"
));
// trt error is arised if conv3d_transpose and SAME
if
(
op_type
==
"conv3d_transpose"
&&
padding_algorithm
==
"SAME"
&&
!
with_dynamic_shape
)
{
return
false
;
}
}
#if !IS_TRT_VERSION_GE(7000)
// looks like some issues with trt6.0
if
(
with_dynamic_shape
)
{
return
false
;
}
#endif
std
::
vector
<
int
>
paddings
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"paddings"
));
// conv3d and conv3d_transpose need padding check
if
(
paddings
.
size
()
>
3
)
return
false
;
if
(
desc
.
Input
(
"Input"
).
size
()
!=
1
)
{
VLOG
(
3
)
<<
"TRT Conv3d expect 1 input, but got "
<<
desc
.
Input
(
"Input"
).
size
()
<<
" input."
;
return
false
;
}
if
(
desc
.
Input
(
"Filter"
).
size
()
!=
1
)
{
VLOG
(
3
)
<<
"TRT Conv3d expect 1 filter, but got "
<<
desc
.
Input
(
"Filter"
).
size
()
<<
" filter."
;
return
false
;
}
if
(
op_type
==
"conv3d_transpose"
)
{
if
(
!
desc
.
HasAttr
(
"dilations"
))
{
return
false
;
}
else
{
const
std
::
vector
<
int
>
dilations
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
desc
.
GetAttr
(
"dilations"
));
if
(
dilations
[
0
]
!=
1
||
dilations
[
1
]
!=
1
||
dilations
[
2
]
!=
1
)
{
VLOG
(
3
)
<<
"In conv3d_transpose, Dilations must be (1, 1, 1) for "
"tensorRT, but given ("
<<
dilations
[
0
]
<<
", "
<<
dilations
[
1
]
<<
", "
<<
dilations
[
2
]
<<
")"
;
return
false
;
}
}
}
if
(
desc
.
Output
(
"Output"
).
size
()
!=
1
)
{
VLOG
(
3
)
<<
"TRT Conv3d expect 1 output, but got "
<<
desc
.
Output
(
"Output"
).
size
()
<<
" output."
;
return
false
;
}
}
if
((
*
teller
)(
op_type
,
desc
,
use_no_calib_int8
))
return
true
;
}
...
...
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
浏览文件 @
42847d2e
...
...
@@ -66,4 +66,6 @@ set_tests_properties(test_trt_tile_op PROPERTIES TIMEOUT 60)
set_tests_properties
(
test_trt_fc_fuse_quant_dequant_pass PROPERTIES TIMEOUT 100
)
set_tests_properties
(
test_trt_conv_quant_dequant_pass PROPERTIES TIMEOUT 100
)
set_tests_properties
(
test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100
)
set_tests_properties
(
test_trt_conv3d_op PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60
)
endif
()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_op.py
0 → 100644
浏览文件 @
42847d2e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
shutil
import
unittest
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
PassVersionChecker
from
paddle.fluid.core
import
AnalysisConfig
class TensorRTSubgraphPassConv3dTest(InferencePassTest):
    """Builds a single bias-free ``fluid.layers.conv3d`` network with TensorRT enabled.

    ``setUp`` first installs the defaults from ``init_params`` and then calls
    ``set_params``, which subclasses override to vary the convolution.
    """

    def setUp(self):
        self.init_params()
        self.set_params()

        with fluid.program_guard(self.main_program, self.startup_program):
            input_var = fluid.data(
                name="data", shape=[-1, 3, 6, 32, 32], dtype="float32")
            conv_kwargs = dict(
                input=input_var,
                num_filters=self.conv_num_filters,
                filter_size=self.conv_filter_size,
                groups=self.conv_groups,
                padding=self.conv_padding,
                bias_attr=False,
                use_cudnn=self.use_cudnn,
                stride=self.stride,
                act=None)
            conv_result = fluid.layers.conv3d(**conv_kwargs)

        # One random sample matching the non-batch dims declared above.
        sample = np.random.random([1, 3, 6, 32, 32]).astype("float32")
        self.feeds = {"data": sample}
        self.enable_trt = True
        self.trt_parameters = TensorRTSubgraphPassConv3dTest.TensorRTParam(
            1 << 30, 32, 1, self.precision, self.use_static, False)
        self.fetch_list = [conv_result]

    def init_params(self):
        """Default convolution and TensorRT settings shared by the subclasses."""
        self.conv_num_filters = 6
        self.conv_filter_size = 6
        self.conv_groups = 3
        self.conv_padding = [1, 1, 1]
        self.stride = 1
        self.use_cudnn = True
        self.use_static = False
        self.precision = AnalysisConfig.Precision.Float32

    def set_params(self):
        """Hook for subclasses; the base test keeps the defaults."""
        pass

    def test_check_output(self):
        # Nothing is checked on builds without CUDA.
        if not core.is_compiled_with_cuda():
            return
        self.check_output_with_option(True)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassConv3dValidPaddingTest(
        TensorRTSubgraphPassConv3dTest):
    """Variant of the conv3d test that passes 'VALID' as the padding."""

    def set_params(self):
        self.conv_padding = 'VALID'
        self.conv_groups = 3
        self.conv_filter_size = 6
        self.conv_num_filters = 6
class TensorRTSubgraphPassConv3dSamePaddingTest(
        TensorRTSubgraphPassConv3dTest):
    """Variant of the conv3d test that passes 'SAME' as the padding."""

    def set_params(self):
        self.conv_padding = 'SAME'
        self.conv_groups = 3
        self.conv_filter_size = 6
        self.conv_num_filters = 6
class TensorRTSubgraphPassConv3dPaddingTest(TensorRTSubgraphPassConv3dTest):
    """Variant of the conv3d test with the explicit padding list [2, 3, 3]."""

    def set_params(self):
        self.conv_padding = [2, 3, 3]
        self.conv_groups = 3
        self.conv_filter_size = 6
        self.conv_num_filters = 6
class TensorRTSubgraphPassConv3dStrideTest(TensorRTSubgraphPassConv3dTest):
    """Variant of the conv3d test with 'SAME' padding and stride [1, 2, 2]."""

    def set_params(self):
        self.stride = [1, 2, 2]
        self.conv_padding = 'SAME'
        self.conv_groups = 3
        self.conv_filter_size = 6
        self.conv_num_filters = 6
class DynamicShapeTensorRTSubgraphPassConv3dTest(InferencePassTest):
    """Conv3d TensorRT test whose input leaves every dim but channels unknown.

    The input is declared as ``[-1, 6, -1, -1, -1]`` and shape dictionaries are
    supplied through ``DynamicShapeParam``.
    """

    def setUp(self):
        self.set_params()

        with fluid.program_guard(self.main_program, self.startup_program):
            input_var = fluid.data(
                name="data", shape=[-1, 6, -1, -1, -1], dtype="float32")
            conv_kwargs = dict(
                input=input_var,
                num_filters=self.conv_num_filters,
                filter_size=self.conv_filter_size,
                groups=self.conv_groups,
                padding=self.conv_padding,
                bias_attr=False,
                use_cudnn=self.use_cudnn,
                stride=self.stride,
                act=None)
            conv_result = fluid.layers.conv3d(**conv_kwargs)

        sample = np.random.random([1, 6, 32, 32, 8]).astype("float32")
        self.feeds = {"data": sample}
        self.enable_trt = True

        test_cls = DynamicShapeTensorRTSubgraphPassConv3dTest
        self.trt_parameters = test_cls.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)

        # NOTE(review): the three dictionaries are presumably the minimum,
        # maximum and optimal shapes, in that order -- confirm against
        # InferencePassTest.DynamicShapeParam.
        first_shapes = {
            "data": [1, 6, 8, 8, 8],
            "conv3d_0.tmp_0": [1, 6, 8, 8, 4],
        }
        second_shapes = {
            "data": [32, 6, 32, 32, 8],
            "conv3d_0.tmp_0": [32, 6, 32, 32, 8],
        }
        third_shapes = {
            "data": [16, 6, 16, 16, 8],
            "conv3d_0.tmp_0": [16, 6, 16, 16, 8],
        }
        self.dynamic_shape_params = test_cls.DynamicShapeParam(
            first_shapes, second_shapes, third_shapes, False)
        self.fetch_list = [conv_result]

    def set_params(self):
        """Convolution settings used by ``setUp``."""
        self.conv_num_filters = 6
        self.conv_filter_size = 6
        self.conv_groups = 6
        self.conv_padding = 'SAME'
        self.stride = [2, 2, 2]
        self.use_cudnn = True

    def test_check_output(self):
        # Nothing is checked on builds without CUDA.
        if not core.is_compiled_with_cuda():
            return
        self.check_output_with_option(True)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
# Run the test cases in this module when it is executed as a script.
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv3d_transpose_op.py
0 → 100644
浏览文件 @
42847d2e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
shutil
import
unittest
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
PassVersionChecker
from
paddle.fluid.core
import
AnalysisConfig
class TensorRTSubgraphPassConv3dTransposeTest(InferencePassTest):
    """Builds a single bias-free ``fluid.layers.conv3d_transpose`` network with TensorRT enabled.

    Subclasses override ``set_params`` to vary the layer configuration; the
    stride is fixed to 1 in ``setUp``.
    """

    def setUp(self):
        self.set_params()

        with fluid.program_guard(self.main_program, self.startup_program):
            input_var = fluid.data(
                name="data", shape=[-1, 4, 4, 32, 32], dtype="float32")
            deconv_kwargs = dict(
                input=input_var,
                num_filters=self.conv_num_filters,
                filter_size=self.conv_filter_size,
                groups=self.conv_groups,
                padding=self.conv_padding,
                bias_attr=False,
                use_cudnn=self.use_cudnn,
                stride=1,
                act=None)
            deconv_result = fluid.layers.conv3d_transpose(**deconv_kwargs)

        # One random sample matching the non-batch dims declared above.
        sample = np.random.random([1, 4, 4, 32, 32]).astype("float32")
        self.feeds = {"data": sample}
        self.enable_trt = True
        self.trt_parameters = TensorRTSubgraphPassConv3dTransposeTest.TensorRTParam(
            1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
        self.fetch_list = [deconv_result]

    def set_params(self):
        """Default layer settings; overridden by the subclasses."""
        self.conv_num_filters = 6
        self.conv_filter_size = 6
        self.conv_groups = 1
        self.conv_padding = [1, 1, 1]
        self.use_cudnn = True

    def test_check_output(self):
        # Nothing is checked on builds without CUDA.
        if not core.is_compiled_with_cuda():
            return
        self.check_output_with_option(True)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassConv3dTransposeSamePaddingTest(
        TensorRTSubgraphPassConv3dTransposeTest):
    """Variant of the conv3d_transpose test that passes 'VALID' as the padding.

    NOTE(review): the class name says "SamePadding" but the padding set here is
    'VALID' -- confirm which one is intended.
    """

    def set_params(self):
        self.use_cudnn = True
        self.conv_padding = 'VALID'
        self.conv_groups = 1
        self.conv_filter_size = 6
        self.conv_num_filters = 6
class TensorRTSubgraphPassConv3dTransposeMultigroupTest(
        TensorRTSubgraphPassConv3dTransposeTest):
    """Variant of the conv3d_transpose test with two groups and 'VALID' padding."""

    def set_params(self):
        self.use_cudnn = True
        self.conv_padding = 'VALID'
        self.conv_groups = 2
        self.conv_filter_size = 6
        self.conv_num_filters = 6
class DynamicShapeTensorRTSubgraphPassConv3dTransposeTest(InferencePassTest):
    """Conv3d_transpose TensorRT test whose input leaves every dim but channels unknown.

    The input is declared as ``[-1, 6, -1, -1, -1]`` and shape dictionaries are
    supplied through ``DynamicShapeParam``.
    """

    def setUp(self):
        self.set_params()

        with fluid.program_guard(self.main_program, self.startup_program):
            input_var = fluid.data(
                name="data", shape=[-1, 6, -1, -1, -1], dtype="float32")
            deconv_kwargs = dict(
                input=input_var,
                num_filters=self.conv_num_filters,
                filter_size=self.conv_filter_size,
                groups=self.conv_groups,
                padding=self.conv_padding,
                bias_attr=False,
                use_cudnn=self.use_cudnn,
                stride=self.stride,
                act=None)
            deconv_result = fluid.layers.conv3d_transpose(**deconv_kwargs)

        sample = np.random.random([1, 6, 32, 32, 8]).astype("float32")
        self.feeds = {"data": sample}
        self.enable_trt = True

        test_cls = DynamicShapeTensorRTSubgraphPassConv3dTransposeTest
        self.trt_parameters = test_cls.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)

        # NOTE(review): the three dictionaries are presumably the minimum,
        # maximum and optimal shapes, in that order -- confirm against
        # InferencePassTest.DynamicShapeParam.
        first_shapes = {
            "data": [1, 6, 8, 8, 8],
            "conv3d_transpose_0.tmp_0": [1, 6, 8, 8, 1],
        }
        second_shapes = {
            "data": [32, 6, 32, 32, 8],
            "conv3d_transpose_0.tmp_0": [32, 6, 64, 64, 16],
        }
        third_shapes = {
            "data": [16, 6, 16, 16, 8],
            "conv3d_transpose_0.tmp_0": [16, 6, 16, 16, 8],
        }
        self.dynamic_shape_params = test_cls.DynamicShapeParam(
            first_shapes, second_shapes, third_shapes, False)
        self.fetch_list = [deconv_result]

    def set_params(self):
        """Layer settings used by ``setUp``."""
        self.conv_num_filters = 6
        self.conv_filter_size = 6
        self.conv_groups = 6
        self.conv_padding = 'SAME'
        self.stride = [2, 2, 2]
        self.use_cudnn = True

    def test_check_output(self):
        # Nothing is checked on builds without CUDA.
        if not core.is_compiled_with_cuda():
            return
        self.check_output_with_option(True)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
# Run the test cases in this module when it is executed as a script.
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录