From 64adfe7a16929c92536be5c0e0699a7bc8db053d Mon Sep 17 00:00:00 2001 From: Zhang Jun Date: Fri, 28 Apr 2023 12:26:17 +0800 Subject: [PATCH] [inference][trt]trt support 0 dims (#53383) * trt support 0 dim * trt support 0 dim * update activation ut --- .../inference/tensorrt/convert/op_converter.h | 16 +- paddle/fluid/inference/tensorrt/op_teller.cc | 5 +- .../operators/tensorrt/tensorrt_engine_op.h | 36 +---- .../inference/test_trt_convert_activation.py | 20 ++- test/ir/inference/test_trt_convert_celu.py | 142 ------------------ .../inference/test_trt_convert_logsigmoid.py | 139 ----------------- test/ir/inference/test_trt_convert_silu.py | 139 ----------------- .../inference/test_trt_convert_tanhshrink.py | 139 ----------------- 8 files changed, 34 insertions(+), 602 deletions(-) delete mode 100644 test/ir/inference/test_trt_convert_celu.py delete mode 100755 test/ir/inference/test_trt_convert_logsigmoid.py delete mode 100755 test/ir/inference/test_trt_convert_silu.py delete mode 100755 test/ir/inference/test_trt_convert_tanhshrink.py diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 87ad887cef3..fcd6146a5f4 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -306,13 +306,15 @@ class OpConverter { auto max_input_shape = engine->max_input_shape()[input]; auto optim_input_shape = engine->optim_input_shape()[input]; size_t ranks = min_input_shape.size(); - if (ranks == 0) { - all_dynamic_shape_set = false; - LOG(INFO) << "trt input [" << input.c_str() - << "] dynamic shape info not set, please check and retry."; - // check other input - continue; - } + // allow 0 dim for dynamic shape input + // if (ranks == 0) { + // all_dynamic_shape_set = false; + // LOG(INFO) << "trt input [" << input.c_str() + // << "] dynamic shape info not set, please check and + // retry."; + // // check other input + // continue; + // } std::vector input_shape; // input_shape.push_back(-1); for (size_t i = 0; i < ranks; i++) { diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6a94e14d7e6..6aff2fb3bf6 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -116,9 +116,10 @@ struct SimpleOpTypeSetTeller : public Teller { auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto x_shape = x_var_desc->GetShape(); - if (x_shape.size() == 1) { + if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.size() == 0)) { VLOG(3) << op_type - << " op does not support input's dim is 1 in tensorrt."; + << " op does not support input's dim is 1 or 0 in tensorrt " + "static shape mode."; return false; } #if !IS_TRT_VERSION_GE(7000) diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 2f6dee8b52d..d942bb0bf5d 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -39,6 +39,7 @@ #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/utils/io_utils.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace inference { @@ -64,19 +65,10 @@ using inference::tensorrt::TRTInt8Calibrator; static void RuntimeStaticShapeCheck(std::vector runtime_input_shape, std::vector model_input_shape) { - auto comma_fold = [](std::string a, int b) { - return std::move(a) + ", " + std::to_string(b); - }; std::string model_input_shape_str = - std::accumulate(std::next(model_input_shape.begin()), - model_input_shape.end(), - std::to_string(model_input_shape[0]), - comma_fold); + string::join_strings(model_input_shape, ','); std::string runtime_input_shape_str = - std::accumulate(std::next(runtime_input_shape.begin()), - runtime_input_shape.end(), - std::to_string(runtime_input_shape[0]), - comma_fold); + string::join_strings(runtime_input_shape, ','); PADDLE_ENFORCE_EQ( model_input_shape == runtime_input_shape, true, @@ -137,24 +129,10 @@ static void RuntimeDynamicShapeCheck( } return true; }; - auto comma_fold = [](std::string a, int b) { - return std::move(a) + ", " + std::to_string(b); - }; std::string runtime_input_shape_str = - std::accumulate(std::next(runtime_input_shape.begin()), - runtime_input_shape.end(), - std::to_string(runtime_input_shape[0]), - comma_fold); - std::string min_input_shape_str = - std::accumulate(std::next(min_input_shape.begin()), - min_input_shape.end(), - std::to_string(min_input_shape[0]), - comma_fold); - std::string max_input_shape_str = - std::accumulate(std::next(max_input_shape.begin()), - max_input_shape.end(), - std::to_string(max_input_shape[0]), - comma_fold); + string::join_strings(runtime_input_shape, ','); + std::string min_input_shape_str = string::join_strings(min_input_shape, ','); + std::string max_input_shape_str = string::join_strings(max_input_shape, ','); PADDLE_ENFORCE_EQ(is_input_shape_valid( runtime_input_shape, min_input_shape, max_input_shape), true, @@ -551,7 +529,6 @@ class TensorRTEngineOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); auto stream = reinterpret_cast(dev_ctx).stream(); - std::vector output_maps = Attr>("output_name_mapping"); @@ -566,7 +543,6 @@ class TensorRTEngineOp : public framework::OperatorBase { trt_context = engine->context(); binding_offset = engine->GetBindingsOffset(); } - // Bind input tensor to TRT. for (const auto &x : runtime_input_names_) { #if IS_TRT_VERSION_LT(8000) diff --git a/test/ir/inference/test_trt_convert_activation.py b/test/ir/inference/test_trt_convert_activation.py index 85583c37f64..bba6beae142 100644 --- a/test/ir/inference/test_trt_convert_activation.py +++ b/test/ir/inference/test_trt_convert_activation.py @@ -33,7 +33,9 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): def sample_program_configs(self): def generate_input1(dims, batch, attrs: List[Dict[str, Any]]): - if dims == 1: + if dims == 0: + return np.random.random([]).astype(np.float32) + elif dims == 1: return np.random.random([32]).astype(np.float32) elif dims == 2: return np.random.random([3, 32]).astype(np.float32) @@ -42,7 +44,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): else: return np.random.random([batch, 3, 32, 32]).astype(np.float32) - for dims in [1, 2, 3, 4]: + for dims in [0, 1, 2, 3, 4]: for batch in [1, 4]: for op_type in [ "relu", @@ -51,9 +53,13 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): "relu6", "elu", "selu", + "silu", "softsign", "stanh", "thresholded_relu", + "celu", + "logsigmoid", + "tanh_shrink", "softplus", ]: # few samples to reduce time @@ -63,6 +69,8 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): for alpha in [0.67]: self.dims = dims dics = [{}] + if op_type == "celu": + dics = [{"alpha": 1.0}] if op_type == "elu": dics = [{"alpha": alpha}] if op_type == "selu": @@ -103,7 +111,11 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): self, program_config ) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): - if self.dims == 1: + if self.dims == 0: + self.dynamic_shape.min_input_shape = {"input_data": []} + self.dynamic_shape.max_input_shape = {"input_data": []} + self.dynamic_shape.opt_input_shape = {"input_data": []} + elif self.dims == 1: self.dynamic_shape.min_input_shape = {"input_data": [1]} self.dynamic_shape.max_input_shape = {"input_data": [64]} self.dynamic_shape.opt_input_shape = {"input_data": [32]} @@ -132,7 +144,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): self.dynamic_shape.opt_input_shape = {} def generate_trt_nodes_num(attrs, dynamic_shape): - if self.dims == 1: + if not dynamic_shape and (self.dims == 1 or self.dims == 0): return 0, 3 return 1, 2 diff --git a/test/ir/inference/test_trt_convert_celu.py b/test/ir/inference/test_trt_convert_celu.py deleted file mode 100644 index e8bba5bf0ed..00000000000 --- a/test/ir/inference/test_trt_convert_celu.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from functools import partial -from typing import Any, Dict, List - -import numpy as np -from program_config import ProgramConfig, TensorConfig -from trt_layer_auto_scan_test import TrtLayerAutoScanTest - -import paddle.inference as paddle_infer - - -class TrtConvertCeluTest(TrtLayerAutoScanTest): - def is_program_valid(self, program_config: ProgramConfig) -> bool: - return True - - def sample_program_configs(self): - def generate_input1(dims, attrs: List[Dict[str, Any]]): - if dims == 1: - return np.ones([3]).astype(np.float32) - elif dims == 2: - return np.ones([3, 64]).astype(np.float32) - elif dims == 3: - return np.ones([3, 64, 64]).astype(np.float32) - else: - return np.ones([1, 3, 64, 64]).astype(np.float32) - - for dims in [1, 2, 3, 4]: - for alpha in [1.0, 2.0, 3.0]: - self.dims = dims - - dics = [{"alpha": alpha}] - - ops_config = [ - { - "op_type": "celu", - "op_inputs": { - "X": ["input_data"], - }, - "op_outputs": {"Out": ["output_data"]}, - "op_attrs": dics[0], - } - ] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={}, - inputs={ - "input_data": TensorConfig( - data_gen=partial(generate_input1, dims, dics) - ) - }, - outputs=["output_data"], - ) - - yield program_config - - def sample_predictor_configs( - self, program_config - ) -> (paddle_infer.Config, List[int], float): - def generate_dynamic_shape(attrs): - if self.dims == 1: - self.dynamic_shape.min_input_shape = {"input_data": [1]} - self.dynamic_shape.max_input_shape = {"input_data": [128]} - self.dynamic_shape.opt_input_shape = {"input_data": [64]} - elif self.dims == 2: - self.dynamic_shape.min_input_shape = {"input_data": [1, 32]} - self.dynamic_shape.max_input_shape = {"input_data": [4, 64]} - self.dynamic_shape.opt_input_shape = {"input_data": [3, 64]} - elif self.dims == 3: - self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 32]} - self.dynamic_shape.max_input_shape = { - "input_data": [10, 64, 64] - } - self.dynamic_shape.opt_input_shape = {"input_data": [3, 64, 64]} - else: - self.dynamic_shape.min_input_shape = { - "input_data": [1, 3, 32, 32] - } - self.dynamic_shape.max_input_shape = { - "input_data": [4, 3, 64, 64] - } - self.dynamic_shape.opt_input_shape = { - "input_data": [1, 3, 64, 64] - } - - def clear_dynamic_shape(): - self.dynamic_shape.min_input_shape = {} - self.dynamic_shape.max_input_shape = {} - self.dynamic_shape.opt_input_shape = {} - - def generate_trt_nodes_num(attrs, dynamic_shape): - if self.dims == 1: - return 0, 3 - return 1, 2 - - attrs = [ - program_config.ops[i].attrs for i in range(len(program_config.ops)) - ] - - # for static_shape - clear_dynamic_shape() - self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False - ), 1e-5 - self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False - ), (1e-3, 1e-3) - - # for dynamic_shape - generate_dynamic_shape(attrs) - self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True - ), 1e-5 - self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True - ), (1e-3, 1e-3) - - def test(self): - self.run_test() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/ir/inference/test_trt_convert_logsigmoid.py b/test/ir/inference/test_trt_convert_logsigmoid.py deleted file mode 100755 index b28a02a94f7..00000000000 --- a/test/ir/inference/test_trt_convert_logsigmoid.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from functools import partial -from typing import Any, Dict, List - -import numpy as np -from program_config import ProgramConfig, TensorConfig -from trt_layer_auto_scan_test import TrtLayerAutoScanTest - -import paddle.inference as paddle_infer - - -class TrtConvertLogSigmoidTest(TrtLayerAutoScanTest): - def is_program_valid(self, program_config: ProgramConfig) -> bool: - return True - - def sample_program_configs(self): - def generate_input1(dims, attrs: List[Dict[str, Any]]): - if dims == 1: - return np.ones([3]).astype(np.float32) - elif dims == 2: - return np.ones([3, 64]).astype(np.float32) - elif dims == 3: - return np.ones([3, 64, 64]).astype(np.float32) - else: - return np.ones([1, 3, 64, 64]).astype(np.float32) - - for dims in [1, 2, 3, 4]: - self.dims = dims - - ops_config = [ - { - "op_type": "logsigmoid", - "op_inputs": { - "X": ["input_data"], - }, - "op_outputs": {"Out": ["output_data"]}, - "op_attrs": {}, - } - ] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={}, - inputs={ - "input_data": TensorConfig( - data_gen=partial(generate_input1, dims, {}) - ) - }, - outputs=["output_data"], - ) - - yield program_config - - def sample_predictor_configs( - self, program_config - ) -> (paddle_infer.Config, List[int], float): - def generate_dynamic_shape(attrs): - if self.dims == 1: - self.dynamic_shape.min_input_shape = {"input_data": [1]} - self.dynamic_shape.max_input_shape = {"input_data": [128]} - self.dynamic_shape.opt_input_shape = {"input_data": [64]} - elif self.dims == 2: - self.dynamic_shape.min_input_shape = {"input_data": [1, 32]} - self.dynamic_shape.max_input_shape = {"input_data": [4, 64]} - self.dynamic_shape.opt_input_shape = {"input_data": [3, 64]} - elif self.dims == 3: - self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 32]} - self.dynamic_shape.max_input_shape = { - "input_data": [10, 64, 64] - } - self.dynamic_shape.opt_input_shape = {"input_data": [3, 64, 64]} - else: - self.dynamic_shape.min_input_shape = { - "input_data": [1, 3, 32, 32] - } - self.dynamic_shape.max_input_shape = { - "input_data": [4, 3, 64, 64] - } - self.dynamic_shape.opt_input_shape = { - "input_data": [1, 3, 64, 64] - } - - def clear_dynamic_shape(): - self.dynamic_shape.min_input_shape = {} - self.dynamic_shape.max_input_shape = {} - self.dynamic_shape.opt_input_shape = {} - - def generate_trt_nodes_num(attrs, dynamic_shape): - if self.dims == 1: - return 0, 3 - return 1, 2 - - attrs = [ - program_config.ops[i].attrs for i in range(len(program_config.ops)) - ] - - # for static_shape - clear_dynamic_shape() - self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False - ), 1e-5 - self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False - ), (1e-3, 1e-3) - - # for dynamic_shape - generate_dynamic_shape(attrs) - self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True - ), 1e-5 - self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True - ), (1e-3, 1e-3) - - def test(self): - self.run_test() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/ir/inference/test_trt_convert_silu.py b/test/ir/inference/test_trt_convert_silu.py deleted file mode 100755 index 052d0bfccf7..00000000000 --- a/test/ir/inference/test_trt_convert_silu.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from functools import partial -from typing import Any, Dict, List - -import numpy as np -from program_config import ProgramConfig, TensorConfig -from trt_layer_auto_scan_test import TrtLayerAutoScanTest - -import paddle.inference as paddle_infer - - -class TrtConvertSiluTest(TrtLayerAutoScanTest): - def is_program_valid(self, program_config: ProgramConfig) -> bool: - return True - - def sample_program_configs(self): - def generate_input1(dims, attrs: List[Dict[str, Any]]): - if dims == 1: - return np.ones([3]).astype(np.float32) - elif dims == 2: - return np.ones([3, 64]).astype(np.float32) - elif dims == 3: - return np.ones([3, 64, 64]).astype(np.float32) - else: - return np.ones([1, 3, 64, 64]).astype(np.float32) - - for dims in [1, 2, 3, 4]: - self.dims = dims - - ops_config = [ - { - "op_type": "silu", - "op_inputs": { - "X": ["input_data"], - }, - "op_outputs": {"Out": ["output_data"]}, - "op_attrs": {}, - } - ] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={}, - inputs={ - "input_data": TensorConfig( - data_gen=partial(generate_input1, dims, {}) - ) - }, - outputs=["output_data"], - ) - - yield program_config - - def sample_predictor_configs( - self, program_config - ) -> (paddle_infer.Config, List[int], float): - def generate_dynamic_shape(attrs): - if self.dims == 1: - self.dynamic_shape.min_input_shape = {"input_data": [1]} - self.dynamic_shape.max_input_shape = {"input_data": [128]} - self.dynamic_shape.opt_input_shape = {"input_data": [64]} - elif self.dims == 2: - self.dynamic_shape.min_input_shape = {"input_data": [1, 32]} - self.dynamic_shape.max_input_shape = {"input_data": [4, 64]} - self.dynamic_shape.opt_input_shape = {"input_data": [3, 64]} - elif self.dims == 3: - self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 32]} - self.dynamic_shape.max_input_shape = { - "input_data": [10, 64, 64] - } - self.dynamic_shape.opt_input_shape = {"input_data": [3, 64, 64]} - else: - self.dynamic_shape.min_input_shape = { - "input_data": [1, 3, 32, 32] - } - self.dynamic_shape.max_input_shape = { - "input_data": [4, 3, 64, 64] - } - self.dynamic_shape.opt_input_shape = { - "input_data": [1, 3, 64, 64] - } - - def clear_dynamic_shape(): - self.dynamic_shape.min_input_shape = {} - self.dynamic_shape.max_input_shape = {} - self.dynamic_shape.opt_input_shape = {} - - def generate_trt_nodes_num(attrs, dynamic_shape): - if self.dims == 1: - return 0, 3 - return 1, 2 - - attrs = [ - program_config.ops[i].attrs for i in range(len(program_config.ops)) - ] - - # for static_shape - clear_dynamic_shape() - self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False - ), 1e-5 - self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False - ), (1e-3, 1e-3) - - # for dynamic_shape - generate_dynamic_shape(attrs) - self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True - ), 1e-5 - self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True - ), (1e-3, 1e-3) - - def test(self): - self.run_test() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/ir/inference/test_trt_convert_tanhshrink.py b/test/ir/inference/test_trt_convert_tanhshrink.py deleted file mode 100755 index 71bb4443794..00000000000 --- a/test/ir/inference/test_trt_convert_tanhshrink.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from functools import partial -from typing import Any, Dict, List - -import numpy as np -from program_config import ProgramConfig, TensorConfig -from trt_layer_auto_scan_test import TrtLayerAutoScanTest - -import paddle.inference as paddle_infer - - -class TrtConvertTanhshrinkTest(TrtLayerAutoScanTest): - def is_program_valid(self, program_config: ProgramConfig) -> bool: - return True - - def sample_program_configs(self): - def generate_input1(dims, attrs: List[Dict[str, Any]]): - if dims == 1: - return np.ones([3]).astype(np.float32) - elif dims == 2: - return np.ones([3, 64]).astype(np.float32) - elif dims == 3: - return np.ones([3, 64, 64]).astype(np.float32) - else: - return np.ones([1, 3, 64, 64]).astype(np.float32) - - for dims in [1, 2, 3, 4]: - self.dims = dims - - ops_config = [ - { - "op_type": "tanh_shrink", - "op_inputs": { - "X": ["input_data"], - }, - "op_outputs": {"Out": ["output_data"]}, - "op_attrs": {}, - } - ] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={}, - inputs={ - "input_data": TensorConfig( - data_gen=partial(generate_input1, dims, {}) - ) - }, - outputs=["output_data"], - ) - - yield program_config - - def sample_predictor_configs( - self, program_config - ) -> (paddle_infer.Config, List[int], float): - def generate_dynamic_shape(attrs): - if self.dims == 1: - self.dynamic_shape.min_input_shape = {"input_data": [1]} - self.dynamic_shape.max_input_shape = {"input_data": [128]} - self.dynamic_shape.opt_input_shape = {"input_data": [64]} - elif self.dims == 2: - self.dynamic_shape.min_input_shape = {"input_data": [1, 32]} - self.dynamic_shape.max_input_shape = {"input_data": [4, 64]} - self.dynamic_shape.opt_input_shape = {"input_data": [3, 64]} - elif self.dims == 3: - self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 32]} - self.dynamic_shape.max_input_shape = { - "input_data": [10, 64, 64] - } - self.dynamic_shape.opt_input_shape = {"input_data": [3, 64, 64]} - else: - self.dynamic_shape.min_input_shape = { - "input_data": [1, 3, 32, 32] - } - self.dynamic_shape.max_input_shape = { - "input_data": [4, 3, 64, 64] - } - self.dynamic_shape.opt_input_shape = { - "input_data": [1, 3, 64, 64] - } - - def clear_dynamic_shape(): - self.dynamic_shape.min_input_shape = {} - self.dynamic_shape.max_input_shape = {} - self.dynamic_shape.opt_input_shape = {} - - def generate_trt_nodes_num(attrs, dynamic_shape): - if self.dims == 1: - return 0, 3 - return 1, 2 - - attrs = [ - program_config.ops[i].attrs for i in range(len(program_config.ops)) - ] - - # for static_shape - clear_dynamic_shape() - self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False - ), 1e-5 - self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False - ), (1e-3, 1e-3) - - # for dynamic_shape - generate_dynamic_shape(attrs) - self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True - ), 1e-5 - self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True - ), (1e-3, 1e-3) - - def test(self): - self.run_test() - - -if __name__ == "__main__": - unittest.main() -- GitLab