未验证 提交 e90f3006 编写于 作者: W wenbin 提交者: GitHub

强化非trt conv判断 (#33150)

* add more conditions

* dynamic shape

* ut

* correct contidions

* commnent

* remove rebandadnt op type

* remove rebandant if
上级 5b910f95
...@@ -143,19 +143,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -143,19 +143,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings")); BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
if (paddings.size() > 2) return false; if (paddings.size() > 2) return false;
// strides > 1 is only supported by trt7.0 above
#if !IS_TRT_VERSION_GE(7000)
if (desc.HasAttr("strides")) {
const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides"));
// there is no issue if strides.size() less than 2
if (strides.size() > 1) {
for (size_t i = 0; i < strides.size(); i++) {
if (strides[i] > 1) return false;
}
}
}
#endif
} }
if (op_type == "pool2d") { if (op_type == "pool2d") {
...@@ -239,15 +226,22 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -239,15 +226,22 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false; return false;
} }
// strides > 1 is only supported by trt7.0 above // strides > 1 and 'SAME' is only supported by trt7.0 above
#if !IS_TRT_VERSION_GE(7000) #if !IS_TRT_VERSION_GE(7000)
if (desc.HasAttr("strides")) { if (op_type == "conv2d" || op_type == "conv2d_fusion" ||
const std::vector<int> strides = op_type == "depthwise_conv2d") {
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides")); if (desc.HasAttr("padding_algorithm") && with_dynamic_shape) {
// there is no issue if strides.size() less than 2 auto padding_algorithm =
if (strides.size() > 1) { BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
for (size_t i = 0; i < strides.size(); i++) { if (padding_algorithm == "SAME" && desc.HasAttr("strides")) {
if (strides[i] > 1) return false; const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides"));
// there is no issue if strides.size() less than 2
if (strides.size() > 1) {
for (size_t i = 0; i < strides.size(); i++) {
if (strides[i] > 1) return false;
}
}
} }
} }
} }
......
...@@ -161,5 +161,70 @@ class TensorRTSubgraphPassDepthwiseConvTransposeTest( ...@@ -161,5 +161,70 @@ class TensorRTSubgraphPassDepthwiseConvTransposeTest(
self.use_cudnn = False self.use_cudnn = False
class DynamicShapeTensorRTSubgraphPassConvTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, -1, -1], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=self.conv_num_filters,
filter_size=self.conv_filter_size,
groups=self.conv_groups,
padding=self.conv_padding,
bias_attr=False,
use_cudnn=self.use_cudnn,
stride=self.stride,
act=None)
self.feeds = {
"data": np.random.random([32, 6, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = DynamicShapeTensorRTSubgraphPassConvTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConvTest.DynamicShapeParam(
{
"conv2d_0.tmp_0": [1, 6, 8, 8],
"data": [1, 6, 8, 8],
"depthwise_conv2d_0.tmp_0": [1, 6, 8, 8]
}, {
"conv2d_0.tmp_0": [32, 6, 64, 64],
"data": [32, 6, 64, 64],
"depthwise_conv2d_0.tmp_0": [32, 6, 64, 64]
}, {
"conv2d_0.tmp_0": [16, 6, 16, 16],
"data": [16, 6, 16, 16],
"depthwise_conv2d_0.tmp_0": [32, 6, 64, 64]
}, False)
self.fetch_list = [conv_out]
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 6
self.conv_padding = 'SAME'
self.use_cudnn = True
self.stride = [2, 2]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class DynamicShapeTensorRTSubgraphPassDepthwiseConvTransposeTest(
DynamicShapeTensorRTSubgraphPassConvTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 6
self.conv_padding = 'SAME'
self.use_cudnn = False
self.stride = [2, 2]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册