From 82f255d0a5ab93b8996897cb8d6a28484694c91a Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Fri, 24 Sep 2021 03:14:07 -0500 Subject: [PATCH] add pool2d convert test (#35923) * add pool2d convert test * modify error * modify error * modify error * modify error * modify error * modify error --- .../inference/tensorrt/convert/pool2d_op.cc | 13 + paddle/fluid/inference/tensorrt/op_teller.cc | 20 ++ .../test_trt_convert_anchor_generator.py | 116 +++++++++ .../test_trt_convert_conv2d_transpose.py | 227 ++++++++++++++++++ .../test_trt_convert_depthwise_conv2d.py | 203 ++++++++++++++++ ..._trt_convert_depthwise_conv2d_transpose.py | 191 +++++++++++++++ .../ir/inference/test_trt_convert_pool2d.py | 148 ++++++++++++ 7 files changed, 918 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_anchor_generator.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d_transpose.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_depthwise_conv2d.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_depthwise_conv2d_transpose.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 90d6392fd64..1898f28c73a 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -87,6 +87,10 @@ class Pool2dOpConverter : public OpConverter { bool adaptive = false; if (op_desc.HasAttr("adaptive")) adaptive = BOOST_GET_CONST(bool, op_desc.GetAttr("adaptive")); + std::string padding_algorithm = "EXPLICIT"; + if (op_desc.HasAttr("padding_algorithm")) + padding_algorithm = + BOOST_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm")); nvinfer1::PoolingType nv_pool_type = nvinfer1::PoolingType::kMAX; nvinfer1::ReduceOperation reduce_operation = @@ -124,6 +128,9 @@ class Pool2dOpConverter : public OpConverter { pool_layer->setStride(nv_strides); pool_layer->setPadding(nv_paddings); pool_layer->setAverageCountExcludesPadding(exclusive); + if (padding_algorithm == "SAME") { + pool_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } layer = pool_layer; } else if (global_pooling) { auto *reduce_layer = TRT_ENGINE_ADD_LAYER(engine_, Reduce, *input1, @@ -159,6 +166,9 @@ class Pool2dOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; pool_layer->setStride(nv_strides); pool_layer->setPadding(nv_paddings); + if (padding_algorithm == "SAME") { + pool_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } pool_layer->setAverageCountExcludesPadding(exclusive); pool_layer->setName(("pool2d (Output: " + output_name + ")").c_str()); pool_layer->getOutput(0)->setName(output_name.c_str()); @@ -198,6 +208,9 @@ class Pool2dOpConverter : public OpConverter { "trt pool layer in converter could not be created.")); pool_layer->setStride(nv_strides); pool_layer->setPadding(nv_paddings); + if (padding_algorithm == "SAME") { + pool_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } pool_layer->setAverageCountExcludesPadding(exclusive); layer = pool_layer; } else { diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index ea630a9c6db..5bfd2f12777 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -172,6 +172,22 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, std::vector paddings = BOOST_GET_CONST(std::vector, desc.GetAttr("paddings")); if (paddings.size() > 2) return false; + if (desc.HasAttr("exclusive")) { + if (BOOST_GET_CONST(bool, desc.GetAttr("exclusive"))) { + std::vector ksize = + BOOST_GET_CONST(std::vector, desc.GetAttr("ksize")); + for (size_t i = 0; i < ksize.size(); i++) { + if (ksize[i] <= paddings[i]) { + VLOG(3) << "the padding size should be less than the filter size " + "for exclusive-counting pooling."; + return false; + } + } + } + } + if (desc.HasAttr("ceil_mode")) { + if (BOOST_GET_CONST(bool, desc.GetAttr("ceil_mode"))) return false; + } if (desc.Input("X").size() != 1) { VLOG(3) << "TRT Pool2d expect 1 input, but got " << desc.Input("X").size(); @@ -440,6 +456,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } + if (op_type == "anchor_generator") { + if (!with_dynamic_shape) return false; + } + if (op_type == "yolo_box") { if (with_dynamic_shape) return false; bool has_attrs = diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_anchor_generator.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_anchor_generator.py new file mode 100644 index 00000000000..bf457a9da40 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_anchor_generator.py @@ -0,0 +1,116 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertAnchorGeneratorTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + def generate_input1(batch, attrs: List[Dict[str, Any]]): + return np.random.random([batch, 3, 64, 64]).astype(np.float32) + + for batch in [1, 2, 4]: + for anchor_sizes in [[64.0, 128.0, 256.0, 512.0]]: + for aspect_ratios in [[0.5, 1, 2], [0.4, 1.2, 3]]: + for variances in [[1.0, 1.0, 1.0, 1.0], + [0.5, 1.0, 0.5, 1.0]]: + for stride in [[16.0, 16.0], [16.0, 32.0]]: + for offset in [0.5, 0.8]: + dics = [{ + "anchor_sizes": anchor_sizes, + "aspect_ratios": aspect_ratios, + "variances": variances, + "stride": stride, + "offset": offset + }] + + ops_config = [{ + "op_type": "anchor_generator", + "op_inputs": { + "Input": ["input_data"] + }, + "op_outputs": { + "Anchors": ["output_anchors"], + "Variances": ["output_variances"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig( + data_gen=partial(generate_input1, + batch, dics)) + }, + outputs=[ + "output_anchors", "output_variances" + ]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 3 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d_transpose.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d_transpose.py new file mode 100644 index 00000000000..82dd492b527 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d_transpose.py @@ -0,0 +1,227 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertConv2dTransposeTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + inputs = program_config.inputs + weights = program_config.weights + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + if inputs['input_data'].shape[1] != weights['conv2d_weight'].shape[ + 1] * attrs[0]['groups']: + return False + + if inputs['input_data'].shape[1] != weights['conv2d_weight'].shape[0]: + return False + + return True + + def sample_program_configs(self): + self.trt_param.workspace_size = 1073741824 + + def generate_input1(batch, num_channels, attrs: List[Dict[str, Any]]): + return np.ones([batch, num_channels, 64, 64]).astype(np.float32) + + def generate_weight1(num_channels, attrs: List[Dict[str, Any]]): + if attrs[0]['groups'] == 1: + return np.random.random( + [num_channels, num_channels, 3, 3]).astype(np.float32) + else: + return np.random.random( + [num_channels, int(num_channels / 2), 3, 3]).astype( + np.float32) + + for num_channels in [2, 4, 6]: + for batch in [1, 2, 4]: + for strides in [[1, 1], [2, 2], [1, 2]]: + for paddings in [[0, 3], [1, 2, 3, 4]]: + for groups in [2]: + for padding_algorithm in [ + 'EXPLICIT', 'SAME', 'VALID' + ]: + for dilations in [[1, 1], [2, 2], [1, 2]]: + for data_format in ['NCHW']: + + self.num_channels = num_channels + dics = [{ + "data_fromat": data_format, + "dilations": dilations, + "padding_algorithm": + padding_algorithm, + "groups": groups, + "paddings": paddings, + "strides": strides, + "data_format": data_format, + "output_size": [], + "output_padding": [] + }] + + ops_config = [{ + "op_type": "conv2d_transpose", + "op_inputs": { + "Input": ["input_data"], + "Filter": ["conv2d_weight"] + }, + "op_outputs": { + "Output": ["output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config( + ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "conv2d_weight": + TensorConfig(data_gen=partial( + generate_weight1, + num_channels, dics)) + }, + inputs={ + "input_data": + TensorConfig(data_gen=partial( + generate_input1, batch, + num_channels, dics)) + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + if self.num_channels == 2: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 2, 32, 32], + "output_data": [1, 24, 32, 32] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 2, 64, 64], + "output_data": [4, 24, 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 2, 64, 64], + "output_data": [1, 24, 64, 64] + } + elif self.num_channels == 4: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 4, 32, 32], + "output_data": [1, 24, 32, 32] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 4, 64, 64], + "output_data": [4, 24, 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 4, 64, 64], + "output_data": [1, 24, 64, 64] + } + else: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 6, 32, 32], + "output_data": [1, 24, 32, 32] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 6, 64, 64], + "output_data": [4, 24, 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 6, 64, 64], + "output_data": [1, 24, 64, 64] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), (1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Int8 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), (1e-5, 1e-5) + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), (1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Int8 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), (1e-5, 1e-5) + + def add_skip_trt_case(self): + def teller1(program_config, predictor_config): + if program_config.ops[0].attrs[ + 'padding_algorithm'] == "SAME" or program_config.ops[ + 0].attrs['padding_algorithm'] == "VALID": + return True + return False + + self.add_skip_case( + teller1, SkipReasons.TRT_NOT_IMPLEMENTED, + "When padding_algorithm is 'SAME' or 'VALID', Trt dose not support. In this case, trt build error is caused by scale op." + ) + + def teller2(program_config, predictor_config): + if program_config.ops[0].attrs['dilations'][ + 0] != 1 or program_config.ops[0].attrs['dilations'][1] != 1: + return True + return False + + self.add_skip_case( + teller2, SkipReasons.TRT_NOT_IMPLEMENTED, + "When dilations's element is not equal 1, there are different behaviors between Trt and Paddle." + ) + + def test(self): + self.add_skip_trt_case() + self.run_test() + + def test_quant(self): + self.add_skip_trt_case() + self.run_test(quant=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_depthwise_conv2d.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_depthwise_conv2d.py new file mode 100644 index 00000000000..e6b3aa30bf8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_depthwise_conv2d.py @@ -0,0 +1,203 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertDepthwiseConv2dTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + inputs = program_config.inputs + weights = program_config.weights + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + if inputs['input_data'].shape[1] != weights['conv2d_weight'].shape[ + 1] * attrs[0]['groups']: + return False + + return True + + def sample_program_configs(self): + self.trt_param.workspace_size = 1073741824 + + def generate_input1(batch, attrs: List[Dict[str, Any]]): + if attrs[0]['groups'] == 1: + return np.ones([batch, 1, 64, 64]).astype(np.float32) + elif attrs[0]['groups'] == 2: + return np.ones([batch, 2, 64, 64]).astype(np.float32) + else: + return np.ones([batch, 3, 64, 64]).astype(np.float32) + + def generate_weight1(attrs: List[Dict[str, Any]]): + return np.random.random([24, 1, 3, 3]).astype(np.float32) + + for batch in [1, 2, 4]: + for strides in [[1, 1], [2, 2], [1, 2]]: + for paddings in [[0, 3], [1, 2, 3, 4]]: + for groups in [1, 2, 3]: + for padding_algorithm in ['EXPLICIT', 'SAME', 'VALID']: + for dilations in [[1, 1], [2, 2], [1, 2]]: + for data_format in ['NCHW']: + + dics = [{ + "data_fromat": data_format, + "dilations": dilations, + "padding_algorithm": padding_algorithm, + "groups": groups, + "paddings": paddings, + "strides": strides, + "data_format": data_format + }] + + ops_config = [{ + "op_type": "depthwise_conv2d", + "op_inputs": { + "Input": ["input_data"], + "Filter": ["conv2d_weight"] + }, + "op_outputs": { + "Output": ["output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "conv2d_weight": + TensorConfig(data_gen=partial( + generate_weight1, dics)) + }, + inputs={ + "input_data": + TensorConfig(data_gen=partial( + generate_input1, batch, dics)) + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + if attrs[0]['groups'] == 1: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 1, 32, 32], + "output_data": [1, 24, 32, 32] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 1, 64, 64], + "output_data": [4, 24, 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 1, 64, 64], + "output_data": [1, 24, 64, 64] + } + elif attrs[0]['groups'] == 2: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 2, 32, 32], + "output_data": [1, 24, 32, 32] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 2, 64, 64], + "output_data": [4, 24, 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 2, 64, 64], + "output_data": [1, 24, 64, 64] + } + else: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 3, 32, 32], + "output_data": [1, 24, 32, 32] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 3, 64, 64], + "output_data": [4, 24, 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 3, 64, 64], + "output_data": [1, 24, 64, 64] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), (1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Int8 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), (1e-5, 1e-5) + + # for dynamic_shape + + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), (1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Int8 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), (1e-5, 1e-5) + + def add_skip_trt_case(self): + def teller1(program_config, predictor_config): + if program_config.ops[0].attrs[ + 'padding_algorithm'] == "SAME" or program_config.ops[ + 0].attrs['padding_algorithm'] == "VALID": + return True + return False + + self.add_skip_case( + teller1, SkipReasons.TRT_NOT_IMPLEMENTED, + "When padding_algorithm is 'SAME' or 'VALID', Trt dose not support. In this case, trt build error is caused by scale op." + ) + + def test(self): + self.add_skip_trt_case() + self.run_test() + + def test_quant(self): + self.add_skip_trt_case() + self.run_test(quant=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_depthwise_conv2d_transpose.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_depthwise_conv2d_transpose.py new file mode 100644 index 00000000000..473925c6cdb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_depthwise_conv2d_transpose.py @@ -0,0 +1,191 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertDepthwiseConv2dTransposeTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + inputs = program_config.inputs + weights = program_config.weights + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + if inputs['input_data'].shape[1] != weights['conv2d_weight'].shape[ + 1] * attrs[0]['groups']: + return False + + if inputs['input_data'].shape[1] != weights['conv2d_weight'].shape[1]: + return False + + if inputs['input_data'].shape[1] != attrs[0]['groups']: + return False + + return True + + def sample_program_configs(self): + self.trt_param.workspace_size = 1073741824 + + def generate_input1(batch, attrs: List[Dict[str, Any]]): + return np.ones( + [batch, attrs[0]['groups'], 64, 64]).astype(np.float32) + + def generate_weight1(attrs: List[Dict[str, Any]]): + return np.random.random( + [attrs[0]['groups'], 1, 3, 3]).astype(np.float32) + + for batch in [1, 2, 4]: + for strides in [[1, 1], [2, 2], [1, 2]]: + for paddings in [[0, 3], [1, 2, 3, 4]]: + for groups in [1, 2, 3]: + for padding_algorithm in ['EXPLICIT', 'SAME', 'VALID']: + for dilations in [[1, 1], [2, 2], [1, 2]]: + for data_format in ['NCHW']: + + dics = [{ + "data_fromat": data_format, + "dilations": dilations, + "padding_algorithm": padding_algorithm, + "groups": groups, + "paddings": paddings, + "strides": strides, + "data_format": data_format, + "output_size": [], + "output_padding": [] + }] + + ops_config = [{ + "op_type": "conv2d_transpose", + "op_inputs": { + "Input": ["input_data"], + "Filter": ["conv2d_weight"] + }, + "op_outputs": { + "Output": ["output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "conv2d_weight": + TensorConfig(data_gen=partial( + generate_weight1, dics)) + }, + inputs={ + "input_data": + TensorConfig(data_gen=partial( + generate_input1, batch, dics)) + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = { + "input_data": [1, attrs[0]['groups'], 32, 32], + "output_data": [1, attrs[0]['groups'], 32, 32] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, attrs[0]['groups'], 64, 64], + "output_data": [4, attrs[0]['groups'], 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, attrs[0]['groups'], 64, 64], + "output_data": [1, attrs[0]['groups'], 64, 64] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), (1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Int8 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), (1e-5, 1e-5) + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), (1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Int8 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), (1e-5, 1e-5) + + def add_skip_trt_case(self): + def teller1(program_config, predictor_config): + if program_config.ops[0].attrs[ + 'padding_algorithm'] == "SAME" or program_config.ops[ + 0].attrs['padding_algorithm'] == "VALID": + return True + return False + + self.add_skip_case( + teller1, SkipReasons.TRT_NOT_IMPLEMENTED, + "When padding_algorithm is 'SAME' or 'VALID', Trt dose not support. In this case, trt build error is caused by scale op." + ) + + def teller2(program_config, predictor_config): + if program_config.ops[0].attrs['dilations'][ + 0] != 1 or program_config.ops[0].attrs['dilations'][1] != 1: + return True + return False + + self.add_skip_case( + teller2, SkipReasons.TRT_NOT_IMPLEMENTED, + "When dilations's element is not equal 1, there are different behaviors between Trt and Paddle." + ) + + def test(self): + self.add_skip_trt_case() + self.run_test() + + def test_quant(self): + self.add_skip_trt_case() + self.run_test(quant=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py new file mode 100644 index 00000000000..3e923b1bd89 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py @@ -0,0 +1,148 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertPool2dTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + self.trt_param.workspace_size = 1073741824 + + def generate_input1(attrs: List[Dict[str, Any]]): + return np.ones([1, 3, 64, 64]).astype(np.float32) + + def generate_weight1(attrs: List[Dict[str, Any]]): + return np.random.random([24, 3, 3, 3]).astype(np.float32) + + for strides in [[1, 1], [2, 2], [1, 2]]: + for paddings in [[0, 2], [0, 3], [1, 2, 3, 4]]: + for pooling_type in ['max', 'avg']: + for padding_algotithm in ['EXPLICIT', 'SAME', 'VAILD']: + for ksize in [[2, 3], [3, 3]]: + for data_format in ['NCHW']: + for global_pooling in [True, False]: + for exclusive in [True, False]: + for adaptive in [True, False]: + for ceil_mode in [True, False]: + self.paddings = paddings + + dics = [{ + "pooling_type": + pooling_type, + "ksize": ksize, + "data_fromat": data_format, + "padding_algorithm": + padding_algotithm, + "paddings": paddings, + "strides": strides, + "data_format": data_format, + "global_pooling": + global_pooling, + "exclusive": exclusive, + "adaptive": adaptive, + "ceil_mode": ceil_mode + }] + + ops_config = [{ + "op_type": "pool2d", + "op_inputs": { + "X": ["input_data"], + }, + "op_outputs": { + "Out": ["output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config( + ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": + TensorConfig( + data_gen=partial( + generate_input1, + dics)) + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + if self.paddings == [0, 3] or attrs[0][ + 'global_pooling'] == True or attrs[0]['ceil_mode'] == True: + return 0, 3 + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + + def add_skip_trt_case(self): + def teller1(program_config, predictor_config): + if len(program_config.ops[0].attrs['paddings']) == 4: + return True + return False + + self.add_skip_case(teller1, SkipReasons.TRT_NOT_IMPLEMENTED, + "4-dims paddings are not support for trt now.") + + def test(self): + self.add_skip_trt_case() + self.run_test() + + +if __name__ == "__main__": + unittest.main() -- GitLab