From 1bd8125f7bfd2aac4270a9a25aee8314cd406c25 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Tue, 5 Apr 2022 21:24:31 +0800 Subject: [PATCH] add fake index and unittest for multiclass_nms3 trt (#41344) * add fake index and unittest for multiclass_nms3 trt * modify unittest --- .../tensorrt/convert/multiclass_nms3_op.cc | 13 +- .../test_trt_convert_multiclass_nms3.py | 181 ++++++++++++++++++ .../inference/test_trt_multiclass_nms_op.py | 2 +- 3 files changed, 194 insertions(+), 2 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multiclass_nms3.py diff --git a/paddle/fluid/inference/tensorrt/convert/multiclass_nms3_op.cc b/paddle/fluid/inference/tensorrt/convert/multiclass_nms3_op.cc index 00f1419f08..a968ea2a2c 100644 --- a/paddle/fluid/inference/tensorrt/convert/multiclass_nms3_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multiclass_nms3_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -38,6 +38,7 @@ class MultiClassNMS3OpConverter : public OpConverter { std::string scores = op_desc.Input("Scores").front(); std::string output_name = op_desc.Output("Out").front(); std::string rois_num_name = op_desc.Output("NmsRoisNum").front(); + std::string index_name = op_desc.Output("Index").front(); auto* bboxes_tensor = engine_->GetITensor(bboxes); auto* scores_tensor = engine_->GetITensor(scores); @@ -122,10 +123,20 @@ class MultiClassNMS3OpConverter : public OpConverter { engine_, Concatenation, concat_inputs.data(), concat_inputs.size()); nms_concat_layer->setAxis(1); + // add fake index as output to be consistent with the outputs of + // multiclass_nms3 + std::vector index(1, 0); + auto constant_layer = TRT_ENGINE_ADD_LAYER( + engine_, Constant, nvinfer1::Dims2(1, 1), + nvinfer1::Weights{nvinfer1::DataType::kINT32, + static_cast(index.data()), 1}); + RreplenishLayerAndOutput(batch_nms_layer, "multiclass_nms3", {rois_num_name}, test_mode); RreplenishLayerAndOutput(nms_concat_layer, "multiclass_nms3", {output_name}, test_mode); + RreplenishLayerAndOutput(constant_layer, "multiclass_nms3", {index_name}, + test_mode); } }; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multiclass_nms3.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multiclass_nms3.py new file mode 100644 index 0000000000..b6a3f0c9cb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multiclass_nms3.py @@ -0,0 +1,181 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + + +class TrtConvertMulticlassNMS3Test(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def create_inference_config(self, use_trt=True) -> paddle_infer.Config: + if use_trt: + config = paddle_infer.Config() + config.disable_glog_info() + config.enable_use_gpu(100, 0) + config.set_optim_cache_dir(self.cache_dir) + config.switch_ir_debug() + config.enable_tensorrt_engine( + max_batch_size=self.trt_param.max_batch_size, + workspace_size=self.trt_param.workspace_size, + min_subgraph_size=self.trt_param.min_subgraph_size, + precision_mode=self.trt_param.precision, + use_static=self.trt_param.use_static, + use_calib_mode=self.trt_param.use_calib_mode) + if len(self.dynamic_shape.min_input_shape + ) != 0 and self.dynamic_shape.min_input_shape.keys( + ) == self.dynamic_shape.max_input_shape.keys( + ) and self.dynamic_shape.min_input_shape.keys( + ) == self.dynamic_shape.opt_input_shape.keys(): + config.set_trt_dynamic_shape_info( + self.dynamic_shape.min_input_shape, + self.dynamic_shape.max_input_shape, + self.dynamic_shape.opt_input_shape, + self.dynamic_shape.disable_trt_plugin_fp16) + return config + else: + config = paddle_infer.Config() + config.switch_ir_debug(True) + config.set_optim_cache_dir(self.cache_dir) + config.disable_glog_info() + return config + + def sample_program_configs(self): + def generate_boxes(batch, num_boxes): + return np.arange( + batch * num_boxes * 4, + dtype=np.float32).reshape([batch, num_boxes, 4]) + + def generate_scores(batch, num_boxes, num_classes): + return np.arange( + batch * num_classes * num_boxes, + dtype=np.float32).reshape([batch, num_classes, num_boxes]) + # return np.random.rand(batch, num_classes, num_boxes).astype(np.float32) + + for batch in [1, 2]: + for num_boxes in [4, 12]: + for num_classes in [2, 6]: + for score_threshold in [0.01, ]: + ops_config = [{ + "op_type": "multiclass_nms3", + "op_inputs": { + "BBoxes": ["input_bboxes"], + "Scores": ["input_scores"], + }, + "op_outputs": { + "Out": ["nms_output_boxes"], + "Index": ["nms_output_index"], + "NmsRoisNum": ["nms_output_num"] + }, + "op_attrs": { + "background_label": -1, + "score_threshold": score_threshold, + "nms_top_k": num_boxes, + "keep_top_k": num_boxes, + "nms_threshold": 0.3, + "normalized": False, + "nms_eta": 1.1 + } + }] + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_bboxes": TensorConfig(data_gen=partial( + generate_boxes, batch, num_boxes)), + "input_scores": TensorConfig( + data_gen=partial(generate_scores, batch, + num_boxes, num_classes)) + }, + outputs=[ + "nms_output_boxes", "nms_output_num", + "nms_output_index" + ]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-2 + + def assert_tensors_near(self, + atol: float, + rtol: float, + tensor: Dict[str, np.array], + baseline: Dict[str, np.array]): + # the order of tensorrt outputs are not consistent with paddle + for key, arr in tensor.items(): + if key == "nms_output_index": + continue + if key == "nms_output_boxes": + basline_arr = np.array( + sorted( + baseline[key].reshape((-1, 6)), + key=lambda i: [i[0], i[1]])) + arr = np.array( + sorted( + arr.reshape((-1, 6)), key=lambda i: [i[0], i[1]])) + else: + basline_arr = np.array(baseline[key].reshape((-1, 1))) + arr = np.array(arr.reshape((-1, 1))) + + self.assertTrue( + basline_arr.shape == arr.shape, + "The output shapes are not equal, the baseline shape is " + + str(basline_arr.shape) + ', but got ' + str(arr.shape)) + diff = abs(basline_arr - arr) + self.assertTrue( + np.allclose( + basline_arr, arr, atol=atol, rtol=rtol), + "Output has diff, Maximum absolute error: {}".format( + np.amax(diff))) + + def assert_op_size(self, trt_engine_num, paddle_op_num): + # tensorrt op num is not consistent with paddle + return True + + def test(self): + self.trt_param.workspace_size = 1 << 20 + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py index 3ca6985985..045261fabb 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -- GitLab